MCPcopy
hub / github.com/PaddlePaddle/PaddleFormers / from_pretrained

Method from_pretrained

paddleformers/peft/lora/lora_model.py:374–479  ·  view source on GitHub ↗
(cls, model, lora_path, **kwargs)

Source from the content-addressed store, hash-verified

372
373 @classmethod
374 def from_pretrained(cls, model, lora_path, **kwargs):
375 load_checkpoint_format = kwargs.pop("load_checkpoint_format", "flex_checkpoint")
376 load_via_cpu = kwargs.pop("load_via_cpu", False)
377 lora_config = kwargs.pop("lora_config", None)
378 # init lora config & lora model
379 if not isinstance(lora_config, LoRAConfig):
380 lora_config = LoRAConfig.from_pretrained(lora_path)
381 # define a new variable to conserve original lora_config.tensor_model_parallel_size value which will update while initializing lora model
382 lora_config_tensor_model_parallel_size = lora_config.tensor_model_parallel_size
383 lora_model = cls(model, lora_config)
384
385 lora_model_index_file = os.path.join(lora_path, SAFE_PEFT_WEIGHTS_INDEX_NAME)
386 if os.path.exists(lora_model_index_file):
387 # load safetensors format file.
388 expected_keys = set(lora_model.get_trainable_state_dict().keys())
389
390 if load_checkpoint_format == "flex_checkpoint":
391 lora_sharded_state_dict = lora_model.sharded_state_dict()
392 metadata_path = os.path.join(lora_path, FLEX_CKPT_AUTO_GENERATED_METADATA)
393
394 # delete the existing metadata file if it exists
395 try:
396 os.remove(metadata_path)
397 except FileNotFoundError:
398 pass
399 except Exception as e:
400 logger.error(f"Failed to delete {metadata_path}: {e}")
401
402 aoa_config = {"aoa_statements": []}
403 for key in lora_sharded_state_dict.keys():
404 if key not in expected_keys:
405 aoa_config["aoa_statements"].append(f"_ -> {key}")
406
407 dist.load_state_dict(
408 lora_sharded_state_dict,
409 path=lora_path,
410 aoa_config=aoa_config,
411 safetensors=True,
412 offload=load_via_cpu,
413 )
414
415 return lora_model
416
417 resolved_archieve_file, sharded_metadata = get_checkpoint_shard_files(
418 pretrained_model_name_or_path=lora_path,
419 index_filename=lora_model_index_file,
420 )
421 loaded_keys = sharded_metadata["all_checkpoint_keys"]
422 missing_keys = expected_keys - set(loaded_keys)
423 if len(missing_keys) > 0:
424 raise ValueError(f"missing_keys: {missing_keys}")
425
426 error_msgs = []
427 for shard_file in resolved_archieve_file:
428 pre_tensor_parallel_split = False
429 if model.config.tensor_model_parallel_size > 1:
430 pre_tensor_parallel_split = True
431 tp_actions = lora_model._get_tensor_parallel_convert_actions(loaded_keys, is_split=True)

Callers 15

setUpClass · Method · 0.45
setUpClass · Method · 0.45
setUpClass · Method · 0.45
_build_dataset · Method · 0.45
_build_dataset · Method · 0.45
_build_dataset · Method · 0.45
_build_dataset · Method · 0.45
_build_dataset · Method · 0.45
setUpClass · Method · 0.45

Calls 15

load_state_dict · Function · 0.85
_add_variant · Function · 0.85
remove · Method · 0.80
load · Method · 0.80
info · Method · 0.80
pop · Method · 0.45
exists · Method · 0.45
keys · Method · 0.45
sharded_state_dict · Method · 0.45

Tested by 15

setUpClass · Method · 0.36
setUpClass · Method · 0.36
setUpClass · Method · 0.36
_build_dataset · Method · 0.36
_build_dataset · Method · 0.36
_build_dataset · Method · 0.36
_build_dataset · Method · 0.36
_build_dataset · Method · 0.36
setUpClass · Method · 0.36