MCPcopy
hub / github.com/huggingface/diffusers / _get_checkpoint_shard_files

Function _get_checkpoint_shard_files

src/diffusers/utils/hub_utils.py:359–467  ·  view source on GitHub ↗

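For a given model: (1) download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the Hub, and (2) return the list of paths to all the shards, as well as some metadata. For the description of each arg, see [`PreTrainedModel.from_pretrained`]. `index_filename` is the full path to the index (downloaded and cached if `pretrained_model_name_or_path` is a model ID on the Hub).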

(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    proxies=None,
    local_files_only=False,
    token=None,
    user_agent=None,
    revision=None,
    subfolder="",
    dduf_entries: dict[str, DDUFEntry] | None = None,
)

Source from the content-addressed store, hash-verified

357
358
def _get_checkpoint_shard_files(
    pretrained_model_name_or_path,
    index_filename,
    cache_dir=None,
    proxies=None,
    local_files_only=False,
    token=None,
    user_agent=None,
    revision=None,
    subfolder="",
    dduf_entries: "dict[str, DDUFEntry] | None" = None,
):
    """
    For a given model:

    - download and cache all the shards of a sharded checkpoint if `pretrained_model_name_or_path` is a model ID on the
      Hub
    - return the list of paths to all the shards, as well as some metadata.

    `index_filename` is the full path to the index (downloaded and cached if `pretrained_model_name_or_path` is a model
    ID on the Hub).

    Args:
        pretrained_model_name_or_path: Local model directory or Hub model ID. When `dduf_entries` is non-empty, this is
            the prefix that, joined with `subfolder`, forms the keys under which the shards are looked up in
            `dduf_entries`.
        index_filename: Path of the checkpoint index JSON on disk, or its key in `dduf_entries`.
        cache_dir, proxies, local_files_only, token, user_agent, revision: Hub download options. They are not read by
            the local-folder / DDUF branch below.
        subfolder: Sub-directory of `pretrained_model_name_or_path` that holds the shards. `None` is treated as `""`.
        dduf_entries: Optional mapping of file names to DDUF entries; only their `read_text()` method is used here.
            When non-empty, the index and shards are looked up in it instead of on the filesystem.

    Returns:
        `(shard_filenames, sharded_metadata)`: the de-duplicated, sorted shard file names from the index's
        `weight_map`, each joined onto `pretrained_model_name_or_path`/`subfolder`; and the index's `"metadata"` dict,
        extended with `"all_checkpoint_keys"` (the keys of `weight_map`) and `"weight_map"` (a shallow copy of it).

    Raises:
        ValueError: If the checkpoint index cannot be found.
        FileNotFoundError: If a shard referenced by the index is missing.
    """
    use_dduf = bool(dduf_entries)

    # The index must be a DDUF entry, or an actual file on disk.
    index_found = index_filename in dduf_entries if use_dduf else os.path.isfile(index_filename)
    if not index_found:
        raise ValueError(f"Can't find a checkpoint index ({index_filename}) in {pretrained_model_name_or_path}.")

    if use_dduf:
        index = json.loads(dduf_entries[index_filename].read_text())
    else:
        # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
        with open(index_filename, "r", encoding="utf-8") as f:
            index = json.load(f)

    weight_map = index["weight_map"]
    original_shard_filenames = sorted(set(weight_map.values()))
    sharded_metadata = index["metadata"]
    sharded_metadata["all_checkpoint_keys"] = list(weight_map.keys())
    sharded_metadata["weight_map"] = weight_map.copy()
    # `subfolder or ""` keeps a `None` subfolder from making os.path.join raise TypeError.
    shards_path = os.path.join(pretrained_model_name_or_path, subfolder or "")

    # First, let's deal with local folder (or DDUF entries).
    if os.path.isdir(pretrained_model_name_or_path) or use_dduf:
        shard_filenames = [os.path.join(shards_path, f) for f in original_shard_filenames]
        for shard_file in shard_filenames:
            shard_found = shard_file in dduf_entries if use_dduf else os.path.exists(shard_file)
            if not shard_found:
                raise FileNotFoundError(
                    f"{shards_path} does not appear to have a file named {shard_file} which is "
                    "required according to the checkpoint index."
                )
        return shard_filenames, sharded_metadata

    # NOTE(review): the listing this excerpt comes from gives the function's span as source lines 359-467 but shows
    # only up to line 416. The branch for Hub model IDs (downloading the shards) is not part of this excerpt, so as
    # written, non-local, non-DDUF inputs fall through and return None.

Callers 1

from_pretrained · Method · 0.85

Calls 1

exists · Method · 0.80

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…