MCPcopy Index your code
hub / github.com/huggingface/diffusers / _get_expected_safetensors_files

Function _get_expected_safetensors_files

tests/testing_utils.py:1641–1722  ·  view source on GitHub ↗
(
        module: torch.nn.Module,
        offload_to_disk_path: str,
        offload_type: str,
        num_blocks_per_group: int | None = None,
        block_modules: list[str] | None = None,
        module_prefix: str = "",
    )

Source from the content-addressed store, hash-verified

1639 )
1640
1641 def _get_expected_safetensors_files(
1642 module: torch.nn.Module,
1643 offload_to_disk_path: str,
1644 offload_type: str,
1645 num_blocks_per_group: int | None = None,
1646 block_modules: list[str] | None = None,
1647 module_prefix: str = "",
1648 ) -> set[str]:
1649 expected_files = set()
1650
1651 def get_hashed_filename(group_id: str) -> str:
1652 short_hash = _compute_group_hash(group_id)
1653 return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
1654
1655 if offload_type == "block_level":
1656 if num_blocks_per_group is None:
1657 raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
1658
1659 block_modules_set = set(block_modules) if block_modules is not None else set()
1660
1661 modules_with_group_offloading = set()
1662 unmatched_modules = []
1663 for name, submodule in module.named_children():
1664 if name in block_modules_set:
1665 new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}."
1666 submodule_files = _get_expected_safetensors_files(
1667 submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix
1668 )
1669 expected_files.update(submodule_files)
1670 modules_with_group_offloading.add(name)
1671
1672 elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
1673 for i in range(0, len(submodule), num_blocks_per_group):
1674 current_modules = submodule[i : i + num_blocks_per_group]
1675 if not current_modules:
1676 continue
1677 group_id = f"{module_prefix}{name}_{i}_{i + len(current_modules) - 1}"
1678 expected_files.add(get_hashed_filename(group_id))
1679 for j in range(i, i + len(current_modules)):
1680 modules_with_group_offloading.add(f"{name}.{j}")
1681 else:
1682 unmatched_modules.append(submodule)
1683
1684 parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
1685 buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
1686
1687 if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0:
1688 expected_files.add(get_hashed_filename(f"{module_prefix}{module.__class__.__name__}_unmatched_group"))
1689
1690 elif offload_type == "leaf_level":
1691 # Handle leaf-level module groups
1692 for name, submodule in module.named_modules():
1693 if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
1694 # These groups will always have parameters, so a file is expected
1695 expected_files.add(get_hashed_filename(name))
1696
1697 # Handle groups for non-leaf parameters/buffers
1698 modules_with_group_offloading = {

Callers 1

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…