(
pretrained_model_name_or_path: str | Path,
*,
weights_name: str,
subfolder: str | None = None,
cache_dir: str | None = None,
force_download: bool = False,
proxies: dict | None = None,
local_files_only: bool = False,
token: str | None = None,
user_agent: dict | str | None = None,
revision: str | None = None,
commit_hash: str | None = None,
dduf_entries: dict[str, DDUFEntry] | None = None,
)
| 226 | |
| 227 | @validate_hf_hub_args |
| 228 | def _get_model_file( |
| 229 | pretrained_model_name_or_path: str | Path, |
| 230 | *, |
| 231 | weights_name: str, |
| 232 | subfolder: str | None = None, |
| 233 | cache_dir: str | None = None, |
| 234 | force_download: bool = False, |
| 235 | proxies: dict | None = None, |
| 236 | local_files_only: bool = False, |
| 237 | token: str | None = None, |
| 238 | user_agent: dict | str | None = None, |
| 239 | revision: str | None = None, |
| 240 | commit_hash: str | None = None, |
| 241 | dduf_entries: dict[str, DDUFEntry] | None = None, |
| 242 | ): |
| 243 | pretrained_model_name_or_path = str(pretrained_model_name_or_path) |
| 244 | |
| 245 | if dduf_entries: |
| 246 | if subfolder is not None: |
| 247 | raise ValueError( |
| 248 | "DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). " |
| 249 | "Please check the DDUF structure" |
| 250 | ) |
| 251 | model_file = ( |
| 252 | weights_name |
| 253 | if pretrained_model_name_or_path == "" |
| 254 | else "/".join([pretrained_model_name_or_path, weights_name]) |
| 255 | ) |
| 256 | if model_file in dduf_entries: |
| 257 | return model_file |
| 258 | else: |
| 259 | raise EnvironmentError(f"Error no file named {weights_name} found in archive {dduf_entries.keys()}.") |
| 260 | elif os.path.isfile(pretrained_model_name_or_path): |
| 261 | return pretrained_model_name_or_path |
| 262 | elif os.path.isdir(pretrained_model_name_or_path): |
| 263 | if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)): |
| 264 | # Load from a PyTorch checkpoint |
| 265 | model_file = os.path.join(pretrained_model_name_or_path, weights_name) |
| 266 | return model_file |
| 267 | elif subfolder is not None and os.path.isfile( |
| 268 | os.path.join(pretrained_model_name_or_path, subfolder, weights_name) |
| 269 | ): |
| 270 | model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name) |
| 271 | return model_file |
| 272 | else: |
| 273 | raise EnvironmentError( |
| 274 | f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}." |
| 275 | ) |
| 276 | else: |
| 277 | # 1. First check if deprecated way of loading from branches is used |
| 278 | if ( |
| 279 | revision in DEPRECATED_REVISION_ARGS |
| 280 | and (weights_name == WEIGHTS_NAME or weights_name == SAFETENSORS_WEIGHTS_NAME) |
| 281 | and version.parse(version.parse(__version__).base_version) >= version.parse("0.22.0") |
| 282 | ): |
| 283 | try: |
| 284 | model_file = hf_hub_download( |
| 285 | pretrained_model_name_or_path, |
no test coverage detected
searching dependent graphs…