Load weights from a model file. Both `safetensors` and PyTorch binary (`.bin`) files are supported. Args: file_path: model file path, ending with `.bin` or `.safetensors`. dtype: optional target torch.dtype for the weights. device: optional torch device; if None, weights are loaded to CPU. Returns: Weights as a state dict.
(
file_path: Union[str, Path],
dtype: Optional[torch.dtype] = None,
device: Optional[Union[str, torch.device]] = None,
)
| 154 | |
| 155 | |
def load_state_dict(
    file_path: Union[str, Path],
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, torch.Tensor]:
    """Load weights from a model file.

    Both `safetensors` (``.safetensors``) and PyTorch binary (``.bin``)
    files are supported.

    Args:
        file_path: model file path, ending with ``.bin`` or ``.safetensors``.
        dtype: optional target data type. Only floating-point tensors are
            cast; integer/bool tensors (e.g. index buffers) keep their
            original dtype, because casting them would corrupt their values.
        device: torch device like, optional. If None, load to cpu.

    Returns:
        Weights as a state dict mapping parameter names to tensors.

    Raises:
        TypeError: if ``dtype`` is given but is not a ``torch.dtype``.
        NotImplementedError: if the file suffix is neither ``.safetensors``
            nor ``.bin``.
    """
    file_path = Path(file_path)
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if dtype is not None and not isinstance(dtype, torch.dtype):
        raise TypeError(f'dtype must be a torch.dtype, but got {dtype!r}')

    if device is None:
        device = 'cpu'

    def _cast(tensor: torch.Tensor) -> torch.Tensor:
        """Cast floating-point tensors to ``dtype``; leave all others as-is."""
        if dtype is not None and tensor.is_floating_point():
            return tensor.to(dtype)
        return tensor

    if file_path.suffix == '.safetensors':
        # load from safetensors file
        from safetensors import safe_open
        with safe_open(file_path, framework='pt', device=device) as f:
            return {name: _cast(f.get_tensor(name)) for name in f.keys()}

    if file_path.suffix == '.bin':
        # load from pytorch bin file
        # NOTE(security): torch.load unpickles arbitrary Python objects; only
        # load .bin files from trusted sources (torch>=1.13 supports
        # ``weights_only=True`` to restrict this).
        state_dict = torch.load(file_path, map_location=device)
        return {name: _cast(tensor) for name, tensor in state_dict.items()}

    raise NotImplementedError(
        f'Support .safetensors or .bin files, but got {str(file_path)}')
| 200 | |
| 201 | |
| 202 | def get_model_path( |
no test coverage detected