MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / load_state_dict

Function load_state_dict

tensorrt_llm/models/convert_utils.py:156–199  ·  view source on GitHub ↗

Load weights from a model file. Both `safetensors` and PyTorch binary (`.bin`) files are supported. Args: file_path: model file path, ending with .bin or .safetensors. dtype: torch.dtype, optional; if given, every loaded tensor is cast to it. device: torch device-like, optional; if None, load to CPU. Returns: Weights as a state dict.

(
    file_path: Union[str, Path],
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[str, torch.device]] = None,
)

Source from the content-addressed store, hash-verified

154
155
def load_state_dict(
    file_path: Union[str, Path],
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, torch.Tensor]:
    """Load weights from a model file.

    `safetensors` or `pytorch binary` is supported; the format is chosen
    from the file suffix.

    Args:
        file_path: model file path, ends with .bin or .safetensors.
        dtype: torch.dtype, optional. If given, every loaded tensor is
            converted with ``tensor.to(dtype)``; if None, tensors keep the
            dtype stored in the file.
        device: torch device like, optional. If None, load to cpu.

    Returns:
        Weights as state dict, mapping each name to its tensor.

    Raises:
        NotImplementedError: if the suffix is neither .safetensors nor .bin.
    """
    file_path = Path(file_path)
    if dtype is not None:
        assert isinstance(dtype, torch.dtype), \
            f'dtype must be a torch.dtype, but got {type(dtype)}'

    if device is None:
        device = 'cpu'

    model_params = {}
    if file_path.suffix == '.safetensors':
        # load from safetensors file
        from safetensors import safe_open

        # safe_open takes the device as a string (e.g. 'cpu', 'cuda:0')
        # rather than a torch.device object, so normalize it here to keep
        # both accepted forms of `device` working on this path. Values
        # that are not torch.device are passed through untouched.
        st_device = str(device) if isinstance(device,
                                               torch.device) else device
        with safe_open(file_path, framework='pt', device=st_device) as f:
            for name in f.keys():
                tensor = f.get_tensor(name)
                if dtype is not None:
                    tensor = tensor.to(dtype)
                model_params[name] = tensor
    elif file_path.suffix == '.bin':
        # load from pytorch bin file
        # NOTE(review): torch.load unpickles the file; only use it on
        # checkpoints from a trusted source (consider weights_only=True
        # where the supported torch versions allow it).
        state_dict = torch.load(file_path, map_location=device)
        for name, tensor in state_dict.items():
            if dtype is not None:
                tensor = tensor.to(dtype)
            model_params[name] = tensor
    else:
        raise NotImplementedError(
            f'Support .safetensors or .bin files, but got {str(file_path)}')
    return model_params
200
201
202def get_model_path(

Callers 8

convert_hf_modelFunction · 0.90
load_from_ckptMethod · 0.85
__init__Method · 0.85
load_from_model_dirMethod · 0.85

Calls 4

get_tensorMethod · 0.80
keysMethod · 0.45
toMethod · 0.45
loadMethod · 0.45

Tested by

no test coverage detected