(self, model_path: str, *,
mappings=(),
index_name: str | None = None,
file_pattern: str | None = None)
| 227 | """torch.load-backed checkpoint over ``*.bin`` / ``*.pt`` shards.""" |
| 228 | |
def __init__(self, model_path: str, *,
             mappings=(),
             index_name: str | None = None,
             file_pattern: str | None = None):
    """Eagerly load every shard into one in-memory ``name -> tensor`` dict.

    Args:
        model_path: Forwarded to ``_gather_shards`` to locate the shards.
        mappings: Iterable of key mappings; copied into a list and handed
            to ``_apply_mappings`` for every key read from a shard.
        index_name: Forwarded unchanged to ``_gather_shards``.
        file_pattern: Forwarded unchanged to ``_gather_shards``.
    """
    self._mappings = list(mappings)
    # Only the shard list is used; the second value returned by
    # _gather_shards is discarded.
    shard_paths, _ = _gather_shards(model_path, index_name, file_pattern)
    self._data: dict[str, torch.Tensor] = {}
    for shard_path in shard_paths:
        # Tensors are placed on CPU, and weights_only=True is passed so
        # torch.load uses its restricted (weights-only) unpickler.
        state = torch.load(shard_path, map_location='cpu', weights_only=True)
        # Shards are merged in order, so if two keys map to the same
        # name the later entry replaces the earlier one.
        self._data.update(
            (_apply_mappings(name, self._mappings), tensor)
            for name, tensor in state.items()
        )
| 240 | |
| 241 | def get(self, key: str, index=None): |
| 242 | t = self._data[key] |
nothing calls this directly
no test coverage detected