(self)
| 264 | return None |
| 265 | |
def get_merged_model(self) -> PreTrainedModel:
    """
    Merge the trained LoRA adapters into the base weights and return the result.

    Non-quantized models are merged directly via ``self.model.merge_and_unload()``.
    Because that call is destructive for ``self.model``, ``self.needs_reload`` is
    set to ``True`` so the model is reloaded before it is used again.

    When ``self.settings.quantization`` is ``QuantizationMethod.BNB_4BIT``, the
    base model is loaded again on the CPU (using ``self.model.dtype``), wrapped
    with a fresh PEFT model built from ``self.peft_config``, and the trained
    ``lora_`` tensors are copied into it before merging. ``self.model`` is not
    merged in this branch and ``self.needs_reload`` is left untouched.

    Returns:
        The model with the adapters merged into its weights.

    Raises:
        AssertionError: If ``self.model`` is not a ``PeftModel``.
        RuntimeError: If, in the quantized branch, a trained adapter tensor has
            no parameter of the same name in the reloaded model.
    """
    # Guard against calling this method at the wrong time.
    assert isinstance(self.model, PeftModel)

    if self.settings.quantization != QuantizationMethod.BNB_4BIT:
        # Non-quantized model - can merge directly
        print("* Merging LoRA adapters into base model...")
        merged_model = self.model.merge_and_unload()
        # merge_and_unload() modifies self.model in-place, destroying LoRA adapters.
        # Mark for full reload if user switches trials later.
        self.needs_reload = True
        return merged_model

    # Quantized models need special handling - we must reload the base model
    # in full precision to merge the LoRA adapters.

    # Snapshot the trained adapter tensors on the CPU before doing anything else.
    adapter_state = {
        name: param.data.clone().cpu()
        for name, param in self.model.named_parameters()
        if "lora_" in name
    }

    # Load base model in full precision on CPU to avoid VRAM issues
    print("* Loading base model on CPU (this may take a while)...")
    base_model = get_model_class(self.settings.model).from_pretrained(
        self.settings.model,
        torch_dtype=self.model.dtype,
        device_map="cpu",
        trust_remote_code=True
        if self.settings.model in self.trusted_models
        else None,
        **self.revision_kwargs,
    )

    # Apply LoRA adapters to the CPU model
    print("* Applying LoRA adapters...")
    peft_model = get_peft_model(base_model, self.peft_config)

    # Copy the trained adapter weights. Entries are removed from adapter_state as
    # they are applied, so whatever is left afterwards was never transferred.
    for name, param in peft_model.named_parameters():
        if name in adapter_state:
            param.data = adapter_state.pop(name).to(param.device)

    # A tensor without a counterpart would silently be left out of the merge,
    # so fail loudly instead of returning a model that lacks trained weights.
    if adapter_state:
        example = next(iter(adapter_state))
        raise RuntimeError(
            f"{len(adapter_state)} trained LoRA tensor(s) have no matching "
            f"parameter in the reloaded model (e.g. {example!r}); "
            "refusing to return an incompletely merged model."
        )

    # Merge and unload
    print("* Merging LoRA adapters into base model...")
    return peft_model.merge_and_unload()
| 314 | |
| 315 | def reset_model(self): |
| 316 | """ |
no test coverage detected