(
module: torch.nn.Module,
offload_to_disk_path: str,
offload_type: str,
num_blocks_per_group: int | None = None,
block_modules: list[str] | None = None,
)
| 1722 | return expected_files |
| 1723 | |
| 1724 | def _check_safetensors_serialization( |
| 1725 | module: torch.nn.Module, |
| 1726 | offload_to_disk_path: str, |
| 1727 | offload_type: str, |
| 1728 | num_blocks_per_group: int | None = None, |
| 1729 | block_modules: list[str] | None = None, |
| 1730 | ) -> bool: |
| 1731 | if not os.path.isdir(offload_to_disk_path): |
| 1732 | return False, None, None |
| 1733 | |
| 1734 | expected_files = _get_expected_safetensors_files( |
| 1735 | module, offload_to_disk_path, offload_type, num_blocks_per_group, block_modules |
| 1736 | ) |
| 1737 | actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors"))) |
| 1738 | missing_files = expected_files - actual_files |
| 1739 | extra_files = actual_files - expected_files |
| 1740 | |
| 1741 | is_correct = not missing_files and not extra_files |
| 1742 | return is_correct, extra_files, missing_files |
| 1743 | |
| 1744 | |
| 1745 | class Expectations(DevicePropertiesUserDict): |
no test coverage detected
searching dependent graphs…