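For a given model:

- download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the Hub
- return the list of paths to all the shards, as well as some metadata.

For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).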
(
pretrained_model_name_or_path,
index_filename,
cache_dir=None,
proxies=None,
local_files_only=False,
token=None,
user_agent=None,
revision=None,
subfolder="",
dduf_entries: dict[str, DDUFEntry] | None = None,
)
| 357 | |
| 358 | |
| 359 | def _get_checkpoint_shard_files( |
| 360 | pretrained_model_name_or_path, |
| 361 | index_filename, |
| 362 | cache_dir=None, |
| 363 | proxies=None, |
| 364 | local_files_only=False, |
| 365 | token=None, |
| 366 | user_agent=None, |
| 367 | revision=None, |
| 368 | subfolder="", |
| 369 | dduf_entries: dict[str, DDUFEntry] | None = None, |
| 370 | ): |
| 371 | """ |
| 372 | For a given model: |
| 373 | |
| 374 | - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the |
| 375 | Hub |
| 376 | - returns the list of paths to all the shards, as well as some metadata. |
| 377 | |
| 378 | For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the |
| 379 | index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub). |
| 380 | """ |
| 381 | if dduf_entries: |
| 382 | if index_filename not in dduf_entries: |
| 383 | raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") |
| 384 | else: |
| 385 | if not os.path.isfile(index_filename): |
| 386 | raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.") |
| 387 | |
| 388 | if dduf_entries: |
| 389 | index = json.loads(dduf_entries[index_filename].read_text()) |
| 390 | else: |
| 391 | with open(index_filename, "r") as f: |
| 392 | index = json.loads(f.read()) |
| 393 | |
| 394 | original_shard_filenames = sorted(set(index["weight_map"].values())) |
| 395 | sharded_metadata = index["metadata"] |
| 396 | sharded_metadata["all_checkpoint_keys"] = list(index["weight_map"].keys()) |
| 397 | sharded_metadata["weight_map"] = index["weight_map"].copy() |
| 398 | shards_path = os.path.join(pretrained_model_name_or_path, subfolder) |
| 399 | |
| 400 | # First, let's deal with local folder. |
| 401 | if os.path.isdir(pretrained_model_name_or_path) or dduf_entries: |
| 402 | shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames] |
| 403 | for shard_file in shard_filenames: |
| 404 | if dduf_entries: |
| 405 | if shard_file not in dduf_entries: |
| 406 | raise FileNotFoundError( |
| 407 | f"{shards_path} does not appear to have a file named {shard_file} which is " |
| 408 | "required according to the checkpoint index." |
| 409 | ) |
| 410 | else: |
| 411 | if not os.path.exists(shard_file): |
| 412 | raise FileNotFoundError( |
| 413 | f"{shards_path} does not appear to have a file named {shard_file} which is " |
| 414 | "required according to the checkpoint index." |
| 415 | ) |
| 416 | return shard_filenames, sharded_metadata |
no test coverage detected
searching dependent graphs…