Reads a PaddlePaddle checkpoint file, raising properly formatted errors if any arise.
(
checkpoint_file: Union[str, os.PathLike],
tensor_parallel_split_mapping=None,
fliter_dict_keys=None,
device="cpu",
ckpt_quant_stage="O0",
quantization_linear_list=None,
quantization_config=None,
dtype=None,
return_numpy=False,
convert_from_hf=True,
transpose_weight_keys=None,
)
| 508 | |
| 509 | |
| 510 | def load_state_dict( |
| 511 | checkpoint_file: Union[str, os.PathLike], |
| 512 | tensor_parallel_split_mapping=None, |
| 513 | fliter_dict_keys=None, |
| 514 | device="cpu", |
| 515 | ckpt_quant_stage="O0", |
| 516 | quantization_linear_list=None, |
| 517 | quantization_config=None, |
| 518 | dtype=None, |
| 519 | return_numpy=False, |
| 520 | convert_from_hf=True, |
| 521 | transpose_weight_keys=None, |
| 522 | ): |
| 523 | """ |
| 524 | Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise. |
| 525 | """ |
| 526 | |
| 527 | if tensor_parallel_split_mapping is None: |
| 528 | tensor_parallel_split_mapping = {} |
| 529 | |
| 530 | if ( |
| 531 | checkpoint_file.endswith(".safetensors") or re.search(r"\.safetensors_shard_\d{4}$", checkpoint_file) |
| 532 | ) and is_safetensors_available(): |
| 533 | # Check format of the archive |
| 534 | with safe_open(checkpoint_file, framework="np") as f: |
| 535 | metadata = {"format": "np"} |
| 536 | |
| 537 | if metadata.get("format", "np") not in ["pd", "np"]: |
| 538 | raise OSError( |
| 539 | f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure " |
| 540 | "you save your model with the `save_pretrained` method." |
| 541 | ) |
| 542 | if metadata.get("format", "np") == "pd": |
| 543 | raise ValueError("Currently unsupport paddle weights file, use numpy instead.") |
| 544 | if metadata.get("format", "np") == "np": |
| 545 | thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1")) |
| 546 | if thread_num > 1: |
| 547 | logger.info(f"Set loading state_dict thread num to {thread_num}") |
| 548 | state_dict, scale_dict = {}, {} |
| 549 | if thread_num <= 1: |
| 550 | with safe_open(checkpoint_file, framework="np") as f: |
| 551 | state_dict, scale_dict = _load_part_state_dict( |
| 552 | list(f.keys()), |
| 553 | checkpoint_file, |
| 554 | tensor_parallel_split_mapping, |
| 555 | fliter_dict_keys, |
| 556 | device, |
| 557 | quantization_linear_list, |
| 558 | quantization_config, |
| 559 | dtype, |
| 560 | return_numpy, |
| 561 | convert_from_hf, |
| 562 | transpose_weight_keys, |
| 563 | ) |
| 564 | else: |
| 565 | # Load state dict in multi-thread to speed up loading |
| 566 | with safe_open(checkpoint_file, framework="np") as f: |
| 567 | keys_groups = _split_keys_evenly(list(f.keys()), thread_num) |
no test coverage detected