(
module: torch.nn.Module,
offload_to_disk_path: str,
offload_type: str,
num_blocks_per_group: int | None = None,
block_modules: list[str] | None = None,
module_prefix: str = "",
)
| 1639 | ) |
| 1640 | |
| 1641 | def _get_expected_safetensors_files( |
| 1642 | module: torch.nn.Module, |
| 1643 | offload_to_disk_path: str, |
| 1644 | offload_type: str, |
| 1645 | num_blocks_per_group: int | None = None, |
| 1646 | block_modules: list[str] | None = None, |
| 1647 | module_prefix: str = "", |
| 1648 | ) -> set[str]: |
| 1649 | expected_files = set() |
| 1650 | |
| 1651 | def get_hashed_filename(group_id: str) -> str: |
| 1652 | short_hash = _compute_group_hash(group_id) |
| 1653 | return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors") |
| 1654 | |
| 1655 | if offload_type == "block_level": |
| 1656 | if num_blocks_per_group is None: |
| 1657 | raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.") |
| 1658 | |
| 1659 | block_modules_set = set(block_modules) if block_modules is not None else set() |
| 1660 | |
| 1661 | modules_with_group_offloading = set() |
| 1662 | unmatched_modules = [] |
| 1663 | for name, submodule in module.named_children(): |
| 1664 | if name in block_modules_set: |
| 1665 | new_prefix = f"{module_prefix}{name}." if module_prefix else f"{name}." |
| 1666 | submodule_files = _get_expected_safetensors_files( |
| 1667 | submodule, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules, new_prefix |
| 1668 | ) |
| 1669 | expected_files.update(submodule_files) |
| 1670 | modules_with_group_offloading.add(name) |
| 1671 | |
| 1672 | elif isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)): |
| 1673 | for i in range(0, len(submodule), num_blocks_per_group): |
| 1674 | current_modules = submodule[i : i + num_blocks_per_group] |
| 1675 | if not current_modules: |
| 1676 | continue |
| 1677 | group_id = f"{module_prefix}{name}_{i}_{i + len(current_modules) - 1}" |
| 1678 | expected_files.add(get_hashed_filename(group_id)) |
| 1679 | for j in range(i, i + len(current_modules)): |
| 1680 | modules_with_group_offloading.add(f"{name}.{j}") |
| 1681 | else: |
| 1682 | unmatched_modules.append(submodule) |
| 1683 | |
| 1684 | parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading) |
| 1685 | buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading) |
| 1686 | |
| 1687 | if len(unmatched_modules) > 0 or len(parameters) > 0 or len(buffers) > 0: |
| 1688 | expected_files.add(get_hashed_filename(f"{module_prefix}{module.__class__.__name__}_unmatched_group")) |
| 1689 | |
| 1690 | elif offload_type == "leaf_level": |
| 1691 | # Handle leaf-level module groups |
| 1692 | for name, submodule in module.named_modules(): |
| 1693 | if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS): |
| 1694 | # These groups will always have parameters, so a file is expected |
| 1695 | expected_files.add(get_hashed_filename(name)) |
| 1696 | |
| 1697 | # Handle groups for non-leaf parameters/buffers |
| 1698 | modules_with_group_offloading = { |
no test coverage detected
searching dependent graphs…