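Pick the right :class:`Checkpoint` subclass for ``model_path``. Precedence matches the legacy ``create_loader``: 1. ``model.safetensors.index.json`` -> SafetensorsCheckpoint (sharded) 2. ``model*.safetensors`` -> SafetensorsCheckpoint (single-file or unsharded) 3. ``pytorch_model.bin.index.json`` -> PytorchCheckpoint 4. ``pytorch_model*.bin`` -> PytorchCheckpoint 5. ``*.safetensors`` -> SafetensorsCheckpoint (extra pattern) 6. ``*.pt`` / ``*.bin`` -> PytorchCheckpoint (extra pattern)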
(model_path: str, *, mappings=())
| 256 | |
| 257 | |
def create_checkpoint(model_path: str, *, mappings=()) -> Checkpoint:
    """Pick the right :class:`Checkpoint` subclass for ``model_path``.

    Precedence matches the legacy ``create_loader``:

    1. ``model.safetensors.index.json`` -> SafetensorsCheckpoint (sharded)
    2. ``model*.safetensors`` -> SafetensorsCheckpoint (single-file or unsharded)
    3. ``pytorch_model.bin.index.json`` -> PytorchCheckpoint
    4. ``pytorch_model*.bin`` -> PytorchCheckpoint
    5. ``*.safetensors`` -> SafetensorsCheckpoint (extra pattern)
    6. ``*.pt`` / ``*.bin`` -> PytorchCheckpoint (extra pattern)

    Rules are tried in the order above and the first one that matches wins.

    Args:
        model_path: Directory holding the checkpoint files. It is treated as
            a literal path: glob metacharacters in it (``*``, ``?``, ``[``)
            are escaped before the file patterns are matched.
        mappings: Forwarded unchanged to the selected checkpoint class.

    Returns:
        A ``SafetensorsCheckpoint`` or ``PytorchCheckpoint``, constructed
        with ``index_name`` (rules 1 and 3) or ``file_pattern`` (all other
        rules).

    Raises:
        RuntimeError: If no rule matches anything under ``model_path``.
    """
    # Local import: the module-level ``glob`` name is the glob *function*
    # (it is called directly below), while ``escape`` lives on the module.
    import glob as _glob_module

    # ``glob`` interprets its whole argument as a pattern, so a directory
    # such as ``ckpt[v2]`` would be read as a character class and silently
    # match nothing. Escape only the directory part; the file patterns
    # joined onto it are meant to stay wildcards.
    glob_root = _glob_module.escape(model_path)

    # The index checks use ``osp.exists`` on the raw path: that call is
    # literal and needs no escaping.
    if osp.exists(osp.join(model_path, SAFE_WEIGHT_INDEX_NAME)):
        return SafetensorsCheckpoint(
            model_path, mappings=mappings,
            index_name=SAFE_WEIGHT_INDEX_NAME)
    if glob(osp.join(glob_root, SAFE_WEIGHT_PATTERN)):
        return SafetensorsCheckpoint(
            model_path, mappings=mappings,
            file_pattern=SAFE_WEIGHT_PATTERN)
    if osp.exists(osp.join(model_path, WEIGHT_INDEX_NAME)):
        return PytorchCheckpoint(
            model_path, mappings=mappings,
            index_name=WEIGHT_INDEX_NAME)
    if glob(osp.join(glob_root, WEIGHT_PATTERN)):
        return PytorchCheckpoint(
            model_path, mappings=mappings,
            file_pattern=WEIGHT_PATTERN)
    if glob(osp.join(glob_root, EXTRA_SAFE_WEIGHT_PATTERN)):
        return SafetensorsCheckpoint(
            model_path, mappings=mappings,
            file_pattern=EXTRA_SAFE_WEIGHT_PATTERN)
    # Fallback PyTorch patterns, tried in the order the sequence lists them.
    for p in EXTRA_WEIGHT_PATTERNS:
        if glob(osp.join(glob_root, p)):
            return PytorchCheckpoint(
                model_path, mappings=mappings, file_pattern=p)
    raise RuntimeError(f'Failed to find valid checkpoint under {model_path!r}')
no test coverage detected