(file_path, torch_dtype=None)
| 75 | |
| 76 | |
| 77 | def load_state_dict(file_path, torch_dtype=None): |
| 78 | if file_path.endswith(".safetensors"): |
| 79 | return load_state_dict_from_safetensors(file_path, torch_dtype=torch_dtype) |
| 80 | else: |
| 81 | return load_state_dict_from_bin(file_path, torch_dtype=torch_dtype) |
| 82 | |
| 83 | |
| 84 | def load_state_dict_from_safetensors(file_path, torch_dtype=None): |
no test coverage detected