MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / from_checkpoint

Method from_checkpoint

tensorrt_llm/models/modeling_utils.py:726–762  ·  view source on GitHub ↗
(
        cls,
        ckpt_dir: str,
        rank: Optional[int] = None,
        config: Optional[PretrainedConfig] = None,
        *,
        preprocess_weights_hook: Optional[Callable[[Dict[str, Tensor]],
                                                   Dict[str, Tensor]]] = None)

Source from the content-addressed store, hash-verified

724
@classmethod
def from_checkpoint(
        cls,
        ckpt_dir: str,
        rank: Optional[int] = None,
        config: Optional[PretrainedConfig] = None,
        *,
        preprocess_weights_hook: Optional[Callable[[Dict[str, Tensor]],
                                                   Dict[str, Tensor]]] = None):
    """Construct a model from a TensorRT-LLM checkpoint directory.

    Reads the per-rank weight file ``rank{N}.safetensors`` from
    ``ckpt_dir``. It then applies the optional user hook and passes the
    weights through ``preprocess_weights``. Finally it builds ``cls(config)``
    and loads the weights into it.

    Args:
        ckpt_dir: Checkpoint directory. It holds ``config.json`` (read only
            when ``config`` is not supplied) and the ``rank{N}.safetensors``
            weight files.
        rank: When not ``None``, forwarded to ``config.set_rank`` before the
            weight file is selected. The call is made on the caller's config
            object when one is passed in.
        config: Pre-built config. When ``None``, it is loaded from
            ``<ckpt_dir>/config.json`` via ``PretrainedConfig.from_json_file``.
        preprocess_weights_hook: Optional callable applied to the raw weight
            dict before ``preprocess_weights``. Its return value replaces
            the weights.

    Returns:
        A new ``cls`` instance with the checkpoint weights loaded.

    Raises:
        FileNotFoundError: If the weight file for the resolved rank does not
            exist in ``ckpt_dir``.
    """
    if config is None:
        config = PretrainedConfig.from_json_file(
            os.path.join(ckpt_dir, 'config.json'))

    if rank is not None:
        config.set_rank(rank)

    ckpt_rank = config.mapping.rank
    cp_size = config.mapping.cp_size
    if cp_size > 1:
        # cp_tp_pp rank -> tp_pp rank: different cp ranks share the same
        # checkpoint file.
        # rank = pp_rank × tp_size × cp_size + tp_rank × cp_size + cp_rank,
        # so rank // cp_size == pp_rank × tp_size + tp_rank.
        ckpt_rank //= cp_size
    weights_path = os.path.join(ckpt_dir, f'rank{ckpt_rank}.safetensors')

    # Use an explicit check instead of ``assert``. Asserts are stripped under
    # ``python -O``, and a bare AssertionError does not say which file is
    # missing.
    if not os.path.isfile(weights_path):
        raise FileNotFoundError(
            f"Checkpoint weights file not found: {weights_path}")
    weights = safetensors.torch.load_file(weights_path)

    # Pruned checkpoints are flagged on the config. The flag is forwarded to
    # both preprocessing and loading.
    is_checkpoint_pruned = getattr(config, 'is_pruned', False)

    if preprocess_weights_hook is not None:
        weights = preprocess_weights_hook(weights)

    weights = preprocess_weights(weights,
                                 config,
                                 from_pruned=is_checkpoint_pruned)
    model = cls(config)
    model.load(weights, from_pruned=is_checkpoint_pruned)
    return model
763
764 def load(self, weights, from_pruned=False):
765 required_names = set()

Callers 8

build_from_hf (Function) · 0.45
test_fp8_quantization (Function) · 0.45
_load_model_from_hf (Method) · 0.45
_load_model_from_ckpt (Method) · 0.45
get_model_format (Function) · 0.45
build_model (Function) · 0.45
refit_engine (Function) · 0.45

Calls 4

preprocess_weights (Function) · 0.85
from_json_file (Method) · 0.45
set_rank (Method) · 0.45
load (Method) · 0.45

Tested by 2

test_fp8_quantization (Function) · 0.36