github.com/PaddlePaddle/PaddleFormers

Function load_state_dict

paddleformers/transformers/model_utils.py:510–618

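Reads a PaddlePaddle checkpoint file. If the file cannot be read, it raises a descriptive error (`OSError` for invalid safetensors metadata, `ValueError` for unsupported paddle-format weights).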

(
    checkpoint_file: Union[str, os.PathLike],
    tensor_parallel_split_mapping=None,
    fliter_dict_keys=None,
    device="cpu",
    ckpt_quant_stage="O0",
    quantization_linear_list=None,
    quantization_config=None,
    dtype=None,
    return_numpy=False,
    convert_from_hf=True,
    transpose_weight_keys=None,
)


def load_state_dict(
    checkpoint_file: Union[str, os.PathLike],
    tensor_parallel_split_mapping=None,
    fliter_dict_keys=None,
    device="cpu",
    ckpt_quant_stage="O0",
    quantization_linear_list=None,
    quantization_config=None,
    dtype=None,
    return_numpy=False,
    convert_from_hf=True,
    transpose_weight_keys=None,
):
    """
    Reads a PaddlePaddle checkpoint file, returning properly formatted errors if they arise.
    """

    if tensor_parallel_split_mapping is None:
        tensor_parallel_split_mapping = {}

    if (
        checkpoint_file.endswith(".safetensors") or re.search(r"\.safetensors_shard_\d{4}$", checkpoint_file)
    ) and is_safetensors_available():
        # Check format of the archive
        with safe_open(checkpoint_file, framework="np") as f:
            metadata = {"format": "np"}

        if metadata.get("format", "np") not in ["pd", "np"]:
            raise OSError(
                f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
                "you save your model with the `save_pretrained` method."
            )
        if metadata.get("format", "np") == "pd":
            raise ValueError("Currently unsupport paddle weights file, use numpy instead.")
        if metadata.get("format", "np") == "np":
            thread_num = int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))
            if thread_num > 1:
                logger.info(f"Set loading state_dict thread num to {thread_num}")
            state_dict, scale_dict = {}, {}
            if thread_num <= 1:
                with safe_open(checkpoint_file, framework="np") as f:
                    state_dict, scale_dict = _load_part_state_dict(
                        list(f.keys()),
                        checkpoint_file,
                        tensor_parallel_split_mapping,
                        fliter_dict_keys,
                        device,
                        quantization_linear_list,
                        quantization_config,
                        dtype,
                        return_numpy,
                        convert_from_hf,
                        transpose_weight_keys,
                    )
            else:
                # Load state dict in multi-thread to speed up loading
                with safe_open(checkpoint_file, framework="np") as f:
                    keys_groups = _split_keys_evenly(list(f.keys()), thread_num)
                    # ... remainder of load_state_dict (source lines 568–618) not shown
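The function decides from the file name alone whether to take the safetensors path. It does so when the name ends in `.safetensors`, or ends in `.safetensors_shard_` followed by exactly four digits, and only if safetensors is installed. It then gates on the archive's `format` metadata. The standalone sketch below restates both checks. `is_safetensors_checkpoint` and `check_archive_format` are hypothetical names, not PaddleFormers APIs.

One deliberate difference from the listing: the sketch calls `os.fspath` first. The parameter is annotated `Union[str, os.PathLike]`, but `str.endswith` is not available on a `pathlib.Path`. Note also that in the visible source `metadata` is hard-coded to `{"format": "np"}`, so the gate always takes the `"np"` branch there. The sketch still models all three outcomes.

```python
import os
import re
from pathlib import PurePath
from typing import Union

# Same pattern as the listing: ".safetensors_shard_" plus exactly four digits at the end.
_SHARD_SUFFIX = re.compile(r"\.safetensors_shard_\d{4}$")


def is_safetensors_checkpoint(checkpoint_file: Union[str, os.PathLike]) -> bool:
    """Hypothetical helper: True if the name looks like a (possibly sharded) safetensors file."""
    # Assumption: normalise PathLike inputs to str before string matching.
    name = os.fspath(checkpoint_file)
    return name.endswith(".safetensors") or _SHARD_SUFFIX.search(name) is not None


def check_archive_format(metadata: dict, checkpoint_file: str) -> None:
    """Hypothetical mirror of the format gate: only numpy-format ("np") archives pass."""
    fmt = metadata.get("format", "np")  # missing key defaults to "np", as in the listing
    if fmt not in ("pd", "np"):
        raise OSError(f"The safetensors archive passed at {checkpoint_file} does not contain valid metadata.")
    if fmt == "pd":
        raise ValueError("Paddle-format weights are not supported; use numpy-format weights instead.")


if __name__ == "__main__":
    print(is_safetensors_checkpoint(PurePath("ckpt") / "model.safetensors"))  # True
    print(is_safetensors_checkpoint("model.safetensors_shard_12"))  # False: needs four digits
```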

Callers (12 total, 4 listed)

from_pretrained (method)
load_tp_checkpoint (function)
from_pretrained (method)
from_pretrained (method)

Calls (15 total, 10 listed)

is_safetensors_available (function)
_load_part_state_dict (function)
_split_keys_evenly (function)
fit_bf16_to_uint16_np (function)
load_torch (function)
paddleformers_load (function)
info (method)
device_guard (function)
get (method)
keys (method)
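The multi-threaded branch relies on `_split_keys_evenly` and `_load_part_state_dict`, but their bodies are not shown on this page. Below is a hypothetical, library-free sketch of the pattern the visible code suggests:

- read `LOAD_STATE_DICT_THREAD_NUM` the same way `load_state_dict` does;
- split the key list into that many near-equal groups;
- load each group in a worker thread;
- merge the partial `(state_dict, scale_dict)` pairs.

`read_thread_num`, `split_keys_evenly`, `load_in_threads`, and the `load_part` callable are illustrative names only. The real helpers may split, schedule, or merge differently.

```python
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List, Sequence, Tuple

# Hypothetical loader signature: takes a key group, returns (state_dict, scale_dict) for it.
PartLoader = Callable[[List[str]], Tuple[Dict[str, object], Dict[str, object]]]


def read_thread_num() -> int:
    # Same environment lookup as load_state_dict; defaults to 1 (single-threaded).
    return int(os.environ.get("LOAD_STATE_DICT_THREAD_NUM", "1"))


def split_keys_evenly(keys: Sequence[str], n: int) -> List[List[str]]:
    """Split keys into n contiguous groups whose sizes differ by at most one."""
    n = max(1, n)
    base, extra = divmod(len(keys), n)
    groups, start = [], 0
    for i in range(n):
        size = base + (1 if i < extra else 0)  # the first `extra` groups get one more key
        groups.append(list(keys[start:start + size]))
        start += size
    return groups


def load_in_threads(keys: Sequence[str], load_part: PartLoader, thread_num: int):
    """Load key groups concurrently and merge the partial dictionaries."""
    if thread_num <= 1:
        return load_part(list(keys))  # single-threaded path: one call over all keys
    state_dict: Dict[str, object] = {}
    scale_dict: Dict[str, object] = {}
    groups = [g for g in split_keys_evenly(keys, thread_num) if g]  # skip empty groups
    with ThreadPoolExecutor(max_workers=thread_num) as pool:
        # pool.map yields results in submission order; keys are disjoint, so update() is safe.
        for part_state, part_scale in pool.map(load_part, groups):
            state_dict.update(part_state)
            scale_dict.update(part_scale)
    return state_dict, scale_dict
```

A thread pool (rather than processes) is a reasonable fit if most of the per-key work releases the GIL, as NumPy-backed I/O often does. That is an assumption, not something this page confirms.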
