hub / github.com/NVIDIA/TensorRT-LLM / load

Method load

tensorrt_llm/models/modeling_utils.py:764–801
(self, weights, from_pruned=False)

Source from the content-addressed store, hash-verified

        return model

    def load(self, weights, from_pruned=False):
        required_names = set()
        for name, param in self.named_parameters():
            if param.is_inited():
                continue
            if name not in weights:
                # Exemption for embedding sharing
                if name.endswith('lm_head.weight') and any(
                        k.endswith('vocab_embedding.weight')
                        for k in weights.keys()):
                    continue
                if name.endswith('lm_head.per_channel_scale') and any(
                        k.endswith('vocab_embedding.per_channel_scale')
                        for k in weights.keys()):
                    continue
            required_names.add(name)

        provided_names = set(weights.keys())

        if not required_names.issubset(provided_names):
            raise RuntimeError(
                f"Required but not provided tensors:{required_names.difference(provided_names)}"
            )
        if not provided_names.issubset(required_names):
            logger.warning(
                f"Provided but not required tensors: {provided_names.difference(required_names)}"
            )

        for name, param in self.named_parameters():
            if name in provided_names:
                if not from_pruned:
                    try:
                        param.value = weights[name]
                    except Exception as e:
                        raise RuntimeError(
                            f"Encounter error '{e}' for parameter '{name}'")
                else:
                    param.set_value_or_dummy(weights[name])

    def save_checkpoint(self, output_dir, save_config=True):
        # multiple ranks could share same config.json, so adding a save_config parameter to let user avoiding writing config.json in all ranks
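
The name checks in `load` can be tried in isolation with a minimal sketch. Everything here is a hypothetical stand-in, not TensorRT-LLM code: `FakeParam`, `check_names`, and `load_sketch` are illustrative names, and `is_inited` is assumed to mean "a value has already been assigned". The `from_pruned` branch is left out because it depends on `set_value_or_dummy`, which is not shown in the listing.

```python
import logging

logger = logging.getLogger("load_sketch")


class FakeParam:
    """Hypothetical stand-in for a model parameter (not the real class)."""

    def __init__(self, value=None):
        self.value = value

    def is_inited(self):
        # Assumption: a parameter counts as initialized once it holds a value.
        return self.value is not None


def check_names(params, weights):
    """Return (missing, unexpected) name sets, following the checks in load()."""
    required = set()
    for name, param in params.items():
        if param.is_inited():
            continue
        if name not in weights:
            # Embedding sharing: an absent lm_head tensor is tolerated when
            # the matching vocab_embedding tensor is provided.
            shared = any(
                name.endswith("lm_head." + suffix)
                and any(k.endswith("vocab_embedding." + suffix) for k in weights)
                for suffix in ("weight", "per_channel_scale")
            )
            if shared:
                continue
        required.add(name)
    provided = set(weights)
    return required - provided, provided - required


def load_sketch(params, weights):
    """Validate names, then assign every provided tensor to its parameter."""
    missing, unexpected = check_names(params, weights)
    if missing:
        raise RuntimeError(f"Required but not provided tensors: {missing}")
    if unexpected:
        logger.warning("Provided but not required tensors: %s", unexpected)
    for name, param in params.items():
        if name in weights:
            param.value = weights[name]


if __name__ == "__main__":
    params = {
        "transformer.vocab_embedding.weight": FakeParam(),
        "lm_head.weight": FakeParam(),   # absent from weights, but exempt
        "ln_f.bias": FakeParam(1.0),     # already initialized, so not required
    }
    weights = {"transformer.vocab_embedding.weight": 3.0, "extra.tensor": 0.0}
    print(check_names(params, weights))
    load_sketch(params, weights)
```

In this sketch a missing tensor is fatal while an unexpected one only produces a warning, matching the asymmetry between the `RuntimeError` and the `logger.warning` call in the listing.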

Callers 15

_masked_compaction · Function · 0.45
_sum_bitmatrix_rows · Function · 0.45
_reduce_grouped · Function · 0.45
_zero_masked_rows · Function · 0.45
_matmul_ogs · Function · 0.45
_load_tile_attrs · Function · 0.45
_p_matmul_ogs · Function · 0.45
load_scale · Function · 0.45
float_to_flex · Function · 0.45
_downcast_to_mxfp · Function · 0.45
_upcast_from_mxfp · Function · 0.45

Calls 6

is_inited · Method · 0.80
set_value_or_dummy · Method · 0.80
named_parameters · Method · 0.45
keys · Method · 0.45
add · Method · 0.45
warning · Method · 0.45

Tested by 3

test_end_to_end · Method · 0.36