MCPcopy
hub / github.com/p-e-w/heretic / get_merged_model

Method get_merged_model

src/heretic/model.py:266–313  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

264 return None
265
def get_merged_model(self) -> PreTrainedModel:
    """
    Merge the trained LoRA adapters into the base weights and return the result.

    For non-quantized models the adapters are merged directly into
    self.model, and self.needs_reload is set because that merge destroys
    the adapters held by self.model.

    For BNB 4-bit quantized models the base model is loaded again on the CPU
    without quantization, fresh adapters are attached to it, the trained
    adapter tensors are copied across, and that copy is merged.

    Raises:
        RuntimeError: If a trained adapter tensor has no identically named
            parameter in the reloaded model. Merging would then silently
            produce a model that lacks the trained changes.
    """
    # Guard against calling this method at the wrong time.
    assert isinstance(self.model, PeftModel)

    if self.settings.quantization != QuantizationMethod.BNB_4BIT:
        # Non-quantized model - can merge directly.
        print("* Merging LoRA adapters into base model...")
        merged_model = self.model.merge_and_unload()
        # merge_and_unload() modifies self.model in-place, destroying LoRA adapters.
        # Mark for full reload if user switches trials later.
        self.needs_reload = True
        return merged_model

    # Quantized models need special handling - we must reload the base model
    # in full precision to merge the LoRA adapters.

    # Snapshot the adapter tensors (on CPU) before we do anything else.
    adapter_state = {
        name: param.data.clone().cpu()
        for name, param in self.model.named_parameters()
        if "lora_" in name
    }

    # Load base model in full precision on CPU to avoid VRAM issues.
    print("* Loading base model on CPU (this may take a while)...")
    base_model = get_model_class(self.settings.model).from_pretrained(
        self.settings.model,
        torch_dtype=self.model.dtype,
        device_map="cpu",
        trust_remote_code=True
        if self.settings.model in self.trusted_models
        else None,
        **self.revision_kwargs,
    )

    # Apply LoRA adapters to the CPU model.
    print("* Applying LoRA adapters...")
    peft_model = get_peft_model(base_model, self.peft_config)

    # Copy the trained adapter weights, keeping track of any tensor that
    # finds no matching parameter in the reloaded model.
    untransferred = set(adapter_state)
    for name, param in peft_model.named_parameters():
        if name in adapter_state:
            param.data = adapter_state[name].to(param.device)
            untransferred.discard(name)

    # The freshly attached adapters hold their initial values, so merging
    # with tensors missing would quietly discard the trained weights.
    if untransferred:
        example = min(untransferred)
        raise RuntimeError(
            f"Failed to transfer {len(untransferred)} LoRA adapter tensor(s) "
            f"to the reloaded base model (e.g. {example!r}); "
            "refusing to merge without the trained adapter weights."
        )

    # Merge and unload.
    print("* Merging LoRA adapters into base model...")
    return peft_model.merge_and_unload()
314
315 def reset_model(self):
316 """

Callers 1

runFunction · 0.95

Calls 1

get_model_classFunction · 0.85

Tested by

no test coverage detected