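Save checkpoint to file. The checkpoint will have 4 fields: ``meta``, ``state_dict``, ``optimizer`` (only when an optimizer is given) and ``amp``. By default ``meta`` will contain version and time info. Args: model (Module): Module whose params are to be saved. filename (str): Checkpoint filename. optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. meta (dict, optional): Metadata to be saved in checkpoint.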
(model, filename, optimizer=None, meta=None)
| 17 | |
| 18 | |
| 19 | def save_checkpoint(model, filename, optimizer=None, meta=None): |
| 20 | """Save checkpoint to file. |
| 21 | |
| 22 | The checkpoint will have 4 fields: ``meta``, ``state_dict`` and |
| 23 | ``optimizer``, ``amp``. By default ``meta`` will contain version |
| 24 | and time info. |
| 25 | |
| 26 | Args: |
| 27 | model (Module): Module whose params are to be saved. |
| 28 | filename (str): Checkpoint filename. |
| 29 | optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. |
| 30 | meta (dict, optional): Metadata to be saved in checkpoint. |
| 31 | """ |
| 32 | if meta is None: |
| 33 | meta = {} |
| 34 | elif not isinstance(meta, dict): |
| 35 | raise TypeError(f'meta must be a dict or None, but got {type(meta)}') |
| 36 | meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) |
| 37 | |
| 38 | if is_module_wrapper(model): |
| 39 | model = model.module |
| 40 | |
| 41 | if hasattr(model, 'CLASSES') and model.CLASSES is not None: |
| 42 | # save class name to the meta |
| 43 | meta.update(CLASSES=model.CLASSES) |
| 44 | |
| 45 | checkpoint = { |
| 46 | 'meta': meta, |
| 47 | 'state_dict': weights_to_cpu(get_state_dict(model)) |
| 48 | } |
| 49 | # save optimizer state dict in the checkpoint |
| 50 | if isinstance(optimizer, Optimizer): |
| 51 | checkpoint['optimizer'] = optimizer.state_dict() |
| 52 | elif isinstance(optimizer, dict): |
| 53 | checkpoint['optimizer'] = {} |
| 54 | for name, optim in optimizer.items(): |
| 55 | checkpoint['optimizer'][name] = optim.state_dict() |
| 56 | |
| 57 | # save amp state dict in the checkpoint |
| 58 | checkpoint['amp'] = apex.amp.state_dict() |
| 59 | |
| 60 | if filename.startswith('pavi://'): |
| 61 | try: |
| 62 | from pavi import modelcloud |
| 63 | from pavi.exception import NodeNotFoundError |
| 64 | except ImportError: |
| 65 | raise ImportError( |
| 66 | 'Please install pavi to load checkpoint from modelcloud.') |
| 67 | model_path = filename[7:] |
| 68 | root = modelcloud.Folder() |
| 69 | model_dir, model_name = osp.split(model_path) |
| 70 | try: |
| 71 | model = modelcloud.get(model_dir) |
| 72 | except NodeNotFoundError: |
| 73 | model = root.create_training_model(model_dir) |
| 74 | with TemporaryDirectory() as tmp_dir: |
| 75 | checkpoint_file = osp.join(tmp_dir, model_name) |
| 76 | with open(checkpoint_file, 'wb') as f: |
no test coverage detected