Repository: github.com/InternLM/lmdeploy

Function create_checkpoint

lmdeploy/turbomind/checkpoint.py:258–294

Pick the right :class:`Checkpoint` subclass for ``model_path``. Precedence matches the legacy ``create_loader``; the full six-step order is listed in the function's docstring.
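Steps 2 and 5 of the precedence list both select ``SafetensorsCheckpoint`` but differ in how much of the filename they constrain. The minimal sketch below uses the standard-library ``fnmatch.fnmatchcase`` to show which of a few hypothetical filenames each pattern accepts; it assumes that ``glob`` applies the same shell-style matching to each filename (case-sensitively on POSIX). The patterns are quoted from the docstring; the filenames are made up for illustration.

```python
from fnmatch import fnmatchcase

# Patterns quoted from the docstring's precedence list.
STEP2_PATTERN = "model*.safetensors"  # SAFE_WEIGHT_PATTERN
STEP5_PATTERN = "*.safetensors"       # EXTRA_SAFE_WEIGHT_PATTERN

# Hypothetical filenames, chosen only to illustrate the two patterns.
names = [
    "model.safetensors",
    "model-00001-of-00002.safetensors",
    "adapter_model.safetensors",
    "consolidated.safetensors",
]

# For each name: (matched by step 2, matched by step 5).
matches = {n: (fnmatchcase(n, STEP2_PATTERN), fnmatchcase(n, STEP5_PATTERN))
           for n in names}

for name, (step2, step5) in matches.items():
    print(f"{name:34} step2={step2!s:5} step5={step5}")
```

A directory whose only weights are ``adapter_model.safetensors`` or ``consolidated.safetensors`` therefore skips step 2 and is picked up at step 5, unless a ``pytorch_model*.bin`` file is also present, in which case step 4 fires first.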

(model_path: str, *, mappings=())
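The bare ``*`` in this signature makes ``mappings`` keyword-only, so callers must write ``mappings=...`` explicitly. The sketch below uses a hypothetical stand-in, ``create_checkpoint_stub``, with the same parameter list (it does not import lmdeploy and builds no checkpoint) to show how Python treats each call form; the ``("rule",)`` value is a placeholder, not a real mapping.

```python
def create_checkpoint_stub(model_path: str, *, mappings=()):
    # Hypothetical stand-in with the same parameters as create_checkpoint;
    # it only echoes its arguments back.
    return model_path, mappings

# Keyword form: accepted.
ok = create_checkpoint_stub("/models/demo", mappings=("rule",))
print(ok)

# Omitting mappings falls back to the empty-tuple default.
print(create_checkpoint_stub("/models/demo"))

# Positional form: Python raises TypeError before the body runs.
positional_rejected = False
try:
    create_checkpoint_stub("/models/demo", ("rule",))
except TypeError as exc:
    positional_rejected = True
    print("TypeError:", exc)
```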

def create_checkpoint(model_path: str, *, mappings=()) -> Checkpoint:
    """Pick the right :class:`Checkpoint` subclass for ``model_path``.

    Precedence matches the legacy ``create_loader``:

    1. ``model.safetensors.index.json`` -> SafetensorsCheckpoint (sharded)
    2. ``model*.safetensors`` -> SafetensorsCheckpoint (single-file or unsharded)
    3. ``pytorch_model.bin.index.json`` -> PytorchCheckpoint
    4. ``pytorch_model*.bin`` -> PytorchCheckpoint
    5. ``*.safetensors`` -> SafetensorsCheckpoint (extra pattern)
    6. ``*.pt`` / ``*.bin`` -> PytorchCheckpoint (extra pattern)
    """
    if osp.exists(osp.join(model_path, SAFE_WEIGHT_INDEX_NAME)):
        return SafetensorsCheckpoint(
            model_path, mappings=mappings,
            index_name=SAFE_WEIGHT_INDEX_NAME)
    if glob(osp.join(model_path, SAFE_WEIGHT_PATTERN)):
        return SafetensorsCheckpoint(
            model_path, mappings=mappings,
            file_pattern=SAFE_WEIGHT_PATTERN)
    if osp.exists(osp.join(model_path, WEIGHT_INDEX_NAME)):
        return PytorchCheckpoint(
            model_path, mappings=mappings,
            index_name=WEIGHT_INDEX_NAME)
    if glob(osp.join(model_path, WEIGHT_PATTERN)):
        return PytorchCheckpoint(
            model_path, mappings=mappings,
            file_pattern=WEIGHT_PATTERN)
    if glob(osp.join(model_path, EXTRA_SAFE_WEIGHT_PATTERN)):
        return SafetensorsCheckpoint(
            model_path, mappings=mappings,
            file_pattern=EXTRA_SAFE_WEIGHT_PATTERN)
    for p in EXTRA_WEIGHT_PATTERNS:
        if glob(osp.join(model_path, p)):
            return PytorchCheckpoint(
                model_path, mappings=mappings, file_pattern=p)
    raise RuntimeError(f'Failed to find valid checkpoint under {model_path!r}')
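The branch order can be reproduced without lmdeploy to check which loader a given directory would get. The sketch below is a hypothetical re-implementation, ``pick_checkpoint``, that returns a ``(class name, constructor keyword, value)`` tuple instead of constructing a checkpoint, plus a hypothetical helper ``pick_for`` that fills a temporary directory with empty files. The constant values are assumptions taken from the docstring's precedence list; in particular, checking ``*.pt`` before ``*.bin`` in ``EXTRA_WEIGHT_PATTERNS`` is assumed from the docstring's wording, not read from the module.

```python
import os.path as osp
import tempfile
from glob import glob
from pathlib import Path

# Assumed values, mirroring the file names and patterns in the docstring.
SAFE_WEIGHT_INDEX_NAME = "model.safetensors.index.json"
SAFE_WEIGHT_PATTERN = "model*.safetensors"
WEIGHT_INDEX_NAME = "pytorch_model.bin.index.json"
WEIGHT_PATTERN = "pytorch_model*.bin"
EXTRA_SAFE_WEIGHT_PATTERN = "*.safetensors"
EXTRA_WEIGHT_PATTERNS = ["*.pt", "*.bin"]


def pick_checkpoint(model_path: str):
    """Return (class name, keyword, value) for the first rule that matches."""
    # Index files are checked by exact name; weight files by glob pattern.
    if osp.exists(osp.join(model_path, SAFE_WEIGHT_INDEX_NAME)):
        return ("SafetensorsCheckpoint", "index_name", SAFE_WEIGHT_INDEX_NAME)
    if glob(osp.join(model_path, SAFE_WEIGHT_PATTERN)):
        return ("SafetensorsCheckpoint", "file_pattern", SAFE_WEIGHT_PATTERN)
    if osp.exists(osp.join(model_path, WEIGHT_INDEX_NAME)):
        return ("PytorchCheckpoint", "index_name", WEIGHT_INDEX_NAME)
    if glob(osp.join(model_path, WEIGHT_PATTERN)):
        return ("PytorchCheckpoint", "file_pattern", WEIGHT_PATTERN)
    if glob(osp.join(model_path, EXTRA_SAFE_WEIGHT_PATTERN)):
        return ("SafetensorsCheckpoint", "file_pattern", EXTRA_SAFE_WEIGHT_PATTERN)
    for p in EXTRA_WEIGHT_PATTERNS:
        if glob(osp.join(model_path, p)):
            return ("PytorchCheckpoint", "file_pattern", p)
    raise RuntimeError(f"Failed to find valid checkpoint under {model_path!r}")


def pick_for(filenames):
    """Create empty files in a temporary directory and report which rule fires."""
    with tempfile.TemporaryDirectory() as d:
        for name in filenames:
            Path(d, name).touch()
        return pick_checkpoint(d)


# Both formats present: safetensors wins because it is checked first.
print(pick_for(["model.safetensors", "pytorch_model.bin"]))
```

Because every branch returns as soon as it matches, a directory holding both safetensors and PyTorch weights is always loaded through ``SafetensorsCheckpoint`` when the safetensors files follow the ``model*`` naming; the extra patterns only act as fallbacks.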

Callers (2)

export (method)
export_iter (method)

Calls (3)

PytorchCheckpoint (class)
join (method)

Tested by

No test coverage detected.