| 372 | |
| 373 | @classmethod |
| 374 | def from_pretrained(cls, model, lora_path, **kwargs): |
| 375 | load_checkpoint_format = kwargs.pop("load_checkpoint_format", "flex_checkpoint") |
| 376 | load_via_cpu = kwargs.pop("load_via_cpu", False) |
| 377 | lora_config = kwargs.pop("lora_config", None) |
| 378 | # init lora config & lora model |
| 379 | if not isinstance(lora_config, LoRAConfig): |
| 380 | lora_config = LoRAConfig.from_pretrained(lora_path) |
| 381 | # define a new variable to conserve original lora_config.tensor_model_parallel_size value which will update while initializing lora model |
| 382 | lora_config_tensor_model_parallel_size = lora_config.tensor_model_parallel_size |
| 383 | lora_model = cls(model, lora_config) |
| 384 | |
| 385 | lora_model_index_file = os.path.join(lora_path, SAFE_PEFT_WEIGHTS_INDEX_NAME) |
| 386 | if os.path.exists(lora_model_index_file): |
| 387 | # load safetensors format file. |
| 388 | expected_keys = set(lora_model.get_trainable_state_dict().keys()) |
| 389 | |
| 390 | if load_checkpoint_format == "flex_checkpoint": |
| 391 | lora_sharded_state_dict = lora_model.sharded_state_dict() |
| 392 | metadata_path = os.path.join(lora_path, FLEX_CKPT_AUTO_GENERATED_METADATA) |
| 393 | |
| 394 | # delete the existing metadata file if it exists |
| 395 | try: |
| 396 | os.remove(metadata_path) |
| 397 | except FileNotFoundError: |
| 398 | pass |
| 399 | except Exception as e: |
| 400 | logger.error(f"Failed to delete {metadata_path}: {e}") |
| 401 | |
| 402 | aoa_config = {"aoa_statements": []} |
| 403 | for key in lora_sharded_state_dict.keys(): |
| 404 | if key not in expected_keys: |
| 405 | aoa_config["aoa_statements"].append(f"_ -> {key}") |
| 406 | |
| 407 | dist.load_state_dict( |
| 408 | lora_sharded_state_dict, |
| 409 | path=lora_path, |
| 410 | aoa_config=aoa_config, |
| 411 | safetensors=True, |
| 412 | offload=load_via_cpu, |
| 413 | ) |
| 414 | |
| 415 | return lora_model |
| 416 | |
| 417 | resolved_archieve_file, sharded_metadata = get_checkpoint_shard_files( |
| 418 | pretrained_model_name_or_path=lora_path, |
| 419 | index_filename=lora_model_index_file, |
| 420 | ) |
| 421 | loaded_keys = sharded_metadata["all_checkpoint_keys"] |
| 422 | missing_keys = expected_keys - set(loaded_keys) |
| 423 | if len(missing_keys) > 0: |
| 424 | raise ValueError(f"missing_keys: {missing_keys}") |
| 425 | |
| 426 | error_msgs = [] |
| 427 | for shard_file in resolved_archieve_file: |
| 428 | pre_tensor_parallel_split = False |
| 429 | if model.config.tensor_model_parallel_size > 1: |
| 430 | pre_tensor_parallel_split = True |
| 431 | tp_actions = lora_model._get_tensor_parallel_convert_actions(loaded_keys, is_split=True) |