Function save_checkpoint

github.com/microsoft/Cream · EfficientViT/downstream/mmcv_custom/runner/checkpoint.py, lines 19–85

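Save a checkpoint to file. The checkpoint has up to four fields: ``meta``, ``state_dict``, ``optimizer`` and ``amp``. ``optimizer`` is written only when an optimizer or a dict of optimizers is passed; ``amp`` always holds the NVIDIA Apex AMP state. By default ``meta`` contains the mmcv version and a timestamp, plus the model's ``CLASSES`` when they are defined.

Args:
model (Module): Module whose parameters are saved. Module wrappers (e.g. DataParallel) are unwrapped first.
filename (str): Checkpoint filename. A ``pavi://`` prefix saves the checkpoint to PAVI modelcloud.
optimizer (Optimizer or dict, optional): Optimizer, or dict of named optimizers, to be saved.
meta (dict, optional): Metadata to be saved in the checkpoint.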
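The field layout described above can be sketched in plain Python without torch, mmcv or apex. This is a minimal illustration, not the repository's code: `build_checkpoint`, `FakeOptimizer` and the `version` placeholder are hypothetical, the model's state dict is passed in directly, and the AMP state is an ordinary dict standing in for `apex.amp.state_dict()`.

```python
import time


class FakeOptimizer:
    """Hypothetical stand-in for torch.optim.Optimizer; only state_dict() is used."""

    def __init__(self, lr):
        self.lr = lr

    def state_dict(self):
        return {'state': {}, 'param_groups': [{'lr': self.lr}]}


def build_checkpoint(state_dict, optimizer=None, meta=None, classes=None,
                     amp_state=None, version='x.y.z'):
    """Assemble a dict with the same fields save_checkpoint writes (hypothetical helper)."""
    # meta: None becomes {}; anything else that is not a dict is rejected
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    # like the original, this mutates a caller-supplied dict in place;
    # 'version' is a placeholder for mmcv.__version__
    meta.update(mmcv_version=version, time=time.asctime())
    if classes is not None:
        meta.update(CLASSES=classes)

    checkpoint = {'meta': meta, 'state_dict': state_dict}
    # 'optimizer' appears only for a single optimizer or a dict of optimizers
    if isinstance(optimizer, FakeOptimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {
            name: optim.state_dict() for name, optim in optimizer.items()}
    # 'amp' is always present; the original stores apex.amp.state_dict()
    checkpoint['amp'] = {} if amp_state is None else amp_state
    return checkpoint


ckpt = build_checkpoint({'w': [1.0]}, optimizer={'gen': FakeOptimizer(0.1)})
print(sorted(ckpt))  # ['amp', 'meta', 'optimizer', 'state_dict']
```

Passing a dict of optimizers (as in GAN-style training with several optimizers) produces a nested ``optimizer`` entry keyed by name, mirroring the loop in the original.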

(model, filename, optimizer=None, meta=None)

# Module-level imports used below (mmcv 1.x layout assumed)
import os.path as osp
import time
from tempfile import TemporaryDirectory

import mmcv
import torch
from mmcv.parallel import is_module_wrapper
from mmcv.runner.checkpoint import get_state_dict, weights_to_cpu
from torch.optim import Optimizer

try:
    import apex
except ImportError:
    print('apex is not installed')


def save_checkpoint(model, filename, optimizer=None, meta=None):
    """Save checkpoint to file.

    The checkpoint will have 4 fields: ``meta``, ``state_dict``,
    ``optimizer`` and ``amp``. By default ``meta`` will contain version
    and time info.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
    meta.update(mmcv_version=mmcv.__version__, time=time.asctime())

    # unwrap DataParallel-style wrappers to reach the underlying model
    if is_module_wrapper(model):
        model = model.module

    if hasattr(model, 'CLASSES') and model.CLASSES is not None:
        # save class name to the meta
        meta.update(CLASSES=model.CLASSES)

    checkpoint = {
        'meta': meta,
        'state_dict': weights_to_cpu(get_state_dict(model))
    }
    # save optimizer state dict in the checkpoint
    if isinstance(optimizer, Optimizer):
        checkpoint['optimizer'] = optimizer.state_dict()
    elif isinstance(optimizer, dict):
        checkpoint['optimizer'] = {}
        for name, optim in optimizer.items():
            checkpoint['optimizer'][name] = optim.state_dict()

    # save amp state dict in the checkpoint (requires NVIDIA apex;
    # raises NameError if the apex import above failed)
    checkpoint['amp'] = apex.amp.state_dict()

    if filename.startswith('pavi://'):
        try:
            from pavi import modelcloud
            from pavi.exception import NodeNotFoundError
        except ImportError:
            raise ImportError(
                'Please install pavi to save checkpoint to modelcloud.')
        model_path = filename[7:]  # strip the 'pavi://' prefix
        root = modelcloud.Folder()
        model_dir, model_name = osp.split(model_path)
        try:
            model = modelcloud.get(model_dir)
        except NodeNotFoundError:
            model = root.create_training_model(model_dir)
        with TemporaryDirectory() as tmp_dir:
            checkpoint_file = osp.join(tmp_dir, model_name)
            with open(checkpoint_file, 'wb') as f:
                torch.save(checkpoint, f)
                f.flush()
            model.create_file(checkpoint_file, name=model_name)
    else:
        mmcv.mkdir_or_exist(osp.dirname(filename))
        # immediately flush buffer
        with open(filename, 'wb') as f:
            torch.save(checkpoint, f)
            f.flush()
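The `pavi://` branch strips the prefix, splits the remainder into a folder and a file name, and stages the serialized checkpoint in a temporary directory before handing it to modelcloud. The sketch below reproduces that staging step with the standard library only. It is an illustration under stated assumptions: `split_pavi_path`, `stage_and_upload` and the `upload` callback are hypothetical names standing in for the modelcloud calls, and `pickle.dump` replaces `torch.save`.

```python
import os.path as osp
import pickle
from tempfile import TemporaryDirectory

PAVI_PREFIX = 'pavi://'


def split_pavi_path(filename):
    """Return (model_dir, model_name) for a 'pavi://' filename, else None."""
    if not filename.startswith(PAVI_PREFIX):
        return None
    # equivalent to filename[7:] in the original, since len('pavi://') == 7
    return osp.split(filename[len(PAVI_PREFIX):])


def stage_and_upload(checkpoint, filename, upload):
    """Write the checkpoint to a temp file, pass it to `upload`, then clean up.

    `upload(model_dir, local_path, model_name)` is a hypothetical callback
    standing in for modelcloud's create_file call.
    """
    parts = split_pavi_path(filename)
    if parts is None:
        raise ValueError(f'not a {PAVI_PREFIX} path: {filename}')
    model_dir, model_name = parts
    with TemporaryDirectory() as tmp_dir:
        checkpoint_file = osp.join(tmp_dir, model_name)
        with open(checkpoint_file, 'wb') as f:
            pickle.dump(checkpoint, f)  # torch.save(checkpoint, f) in the original
            f.flush()
        # upload must happen here: the directory is deleted when the block exits
        upload(model_dir, checkpoint_file, model_name)


uploads = []


def fake_upload(model_dir, local_path, model_name):
    with open(local_path, 'rb') as f:
        uploads.append((model_dir, model_name, pickle.load(f)))


stage_and_upload({'meta': {}}, 'pavi://proj/exp1/epoch_12.pth', fake_upload)
print(uploads)  # [('proj/exp1', 'epoch_12.pth', {'meta': {}})]
```

Calling the upload inside the `with TemporaryDirectory()` block is the important design point: once the block exits, the staged file no longer exists.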

Callers (1)

save_checkpoint (Method) · 0.70

Calls (5)

weights_to_cpu (Function) · 0.90
get_state_dict (Function) · 0.90
update (Method) · 0.45
state_dict (Method) · 0.45
get (Method) · 0.45
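Among the calls, `weights_to_cpu` copies the model's state dict so that every tensor is on the CPU before it is serialized. The sketch below shows that idea without torch. It is an assumption-labeled illustration: `FakeTensor` and `weights_to_cpu_sketch` are hypothetical, and mmcv's real helper operates on `torch.Tensor` values and may also carry over state-dict metadata.

```python
from collections import OrderedDict


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only .device and .cpu() are used."""

    def __init__(self, data, device='cuda:0'):
        self.data = data
        self.device = device

    def cpu(self):
        # return a CPU copy, leaving the original object unchanged
        return FakeTensor(self.data, device='cpu')


def weights_to_cpu_sketch(state_dict):
    """Return a new OrderedDict with every value moved to CPU via .cpu()."""
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu()
    return state_dict_cpu


gpu_weights = OrderedDict(
    [('conv.weight', FakeTensor([0.5])), ('conv.bias', FakeTensor([0.0]))])
cpu_weights = weights_to_cpu_sketch(gpu_weights)
print([(k, v.device) for k, v in cpu_weights.items()])
# [('conv.weight', 'cpu'), ('conv.bias', 'cpu')]
```

Moving weights to the CPU before `torch.save` means the checkpoint does not record a CUDA device for each tensor, so it can be loaded on a machine without that GPU and without passing `map_location`.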

Tested by

no test coverage detected