Repository: github.com/hustvl/Vim

Function get_state_dict

Defined in seg/mmcv_custom/checkpoint.py, lines 521–562.

Returns a dictionary containing the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included, keyed by their parameter and buffer names. This method is modified from `torch.nn.Module.state_dict` to recursively check for and unwrap parallel module wrappers, in case the model has a complicated structure such as nn.Module(nn.Module(DDP)).
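The effect of unwrapping on the keys can be sketched without PyTorch. In this minimal sketch, `ToyModule`, `ToyWrapper`, `is_wrapper` and `collect` are hypothetical stand-ins (not mmcv or torch APIs) for `nn.Module`, a DDP-style wrapper that keeps the real model in `.module`, `is_module_wrapper`, and the key-building recursion. It only flattens values into dotted keys using the same `prefix + name + '.'` scheme, and ignores metadata, hooks and `keep_vars`.

```python
from collections import OrderedDict


class ToyModule:
    """Hypothetical stand-in for nn.Module: own values plus named children."""

    def __init__(self, params=None, children=None):
        self.params = dict(params or {})
        self.children = dict(children or {})


class ToyWrapper(ToyModule):
    """Hypothetical stand-in for a DDP-style wrapper storing the model in `.module`."""

    def __init__(self, module):
        super().__init__(children={'module': module})
        self.module = module


def is_wrapper(module):
    # Plays the role of is_module_wrapper in this sketch.
    return isinstance(module, ToyWrapper)


def collect(module, prefix='', unwrap=True, destination=None):
    """Flatten a module tree into {dotted_name: value}."""
    if destination is None:
        destination = OrderedDict()
    if unwrap and is_wrapper(module):
        module = module.module  # skip the wrapper level, checked at every depth
    for name, value in module.params.items():
        destination[prefix + name] = value
    for name, child in module.children.items():
        if child is not None:
            collect(child, prefix + name + '.', unwrap, destination)
    return destination


# nn.Module(nn.Module(DDP))-style nesting: the wrapper sits below the root.
backbone = ToyModule(params={'weight': 1.0, 'running_mean': 0.0})
encoder = ToyModule(children={'backbone': ToyWrapper(backbone)})
model = ToyModule(children={'encoder': encoder})

print(list(collect(model, unwrap=False)))
# ['encoder.backbone.module.weight', 'encoder.backbone.module.running_mean']
print(list(collect(model)))
# ['encoder.backbone.weight', 'encoder.backbone.running_mean']
```

Without unwrapping, the wrapper's `module.` level leaks into every key below it. Unwrapping keeps the keys identical to those of the bare model, which presumably is what lets checkpoints stay independent of how the model was wrapped.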

(module, destination=None, prefix='', keep_vars=False)
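The `destination` and `prefix` parameters interact with the function's metadata step: `_metadata` is only created when `destination` is None, so a caller-supplied `destination` must already carry a `_metadata` attribute, and each module's metadata is stored under `prefix[:-1]` (its dotted path without the trailing `.`, or `''` for the root). The sketch below isolates just that step. `record_metadata` is a hypothetical helper, the fixed `version=1` stands in for `module._version`, and the `student.`/`teacher.` prefixes are illustrative.

```python
from collections import OrderedDict


def record_metadata(destination, prefix, version=1):
    """Hypothetical helper mirroring the metadata step of get_state_dict.

    As in the original, `_metadata` is created only together with a fresh
    dict, so a caller-supplied destination must already carry `_metadata`.
    """
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    # '' for the root, otherwise the dotted module path without the final '.'
    destination._metadata[prefix[:-1]] = dict(version=version)
    return destination


# Fresh call at the root: prefix '' gives the metadata key ''.
fresh = record_metadata(None, '')
print(dict(fresh._metadata))  # {'': {'version': 1}}

# Collecting two models into one dict under different prefixes.
merged = OrderedDict()
merged._metadata = OrderedDict()  # must exist before the dict is passed in
record_metadata(merged, 'student.')
record_metadata(merged, 'teacher.')
print(list(merged._metadata))  # ['student', 'teacher']

# A plain OrderedDict without `_metadata` fails at the metadata write.
try:
    record_metadata(OrderedDict(), 'student.')
    missing_metadata_raises = False
except AttributeError:
    missing_metadata_raises = True
print(missing_metadata_raises)  # True
```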


```python
# Imports this function needs; _save_to_state_dict is a helper
# defined earlier in this checkpoint.py.
from collections import OrderedDict

from mmcv.parallel import is_module_wrapper


def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    # per-module metadata is keyed by the dotted path without its trailing '.'
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    # this module's own parameters and persistent buffers
    _save_to_state_dict(module, destination, prefix, keep_vars)
    # recurse into children, extending the key prefix with each child's name
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    # state_dict hooks may edit the dict in place or return a replacement
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination
```
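The traversal order and the hook handling can be checked on a toy tree. This is a minimal pure-Python sketch, not PyTorch: `ToyModule`, `toy_get_state_dict` and the two hook functions are hypothetical, copying the function's control flow (metadata, own values, children, then hooks) while leaving out wrapper unwrapping and `keep_vars`. It illustrates a detail that follows from the listing: the return value of each recursive call is discarded, so a hook on a child module can edit the shared dict in place, but a replacement dict it returns is dropped. Only hooks on the top-level module can replace the result.

```python
from collections import OrderedDict


class ToyModule:
    """Hypothetical stand-in for nn.Module (not PyTorch)."""

    def __init__(self, params=None, children=None, hooks=None):
        self._params = dict(params or {})
        self._modules = dict(children or {})
        self._state_dict_hooks = OrderedDict(enumerate(hooks or []))
        self._version = 1


def toy_get_state_dict(module, destination=None, prefix=''):
    """Same control flow as get_state_dict, minus unwrapping and keep_vars."""
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    for name, value in module._params.items():  # stands in for _save_to_state_dict
        destination[prefix + name] = value
    for name, child in module._modules.items():
        if child is not None:
            # The child's return value is discarded, as in the original.
            toy_get_state_dict(child, destination, prefix + name + '.')
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination


def drop_tmp_in_place(module, destination, prefix, local_metadata):
    # Mutates the shared dict and returns None: the edit survives.
    for key in [k for k in destination if k.startswith(prefix + 'tmp')]:
        del destination[key]


def replace_without_bias(module, destination, prefix, local_metadata):
    # Returns a new dict: only honoured by the call that ran the hook.
    return OrderedDict(
        (k, v) for k, v in destination.items() if not k.endswith('bias'))


head = ToyModule(params={'weight': 1, 'bias': 0, 'tmp': 9},
                 hooks=[drop_tmp_in_place, replace_without_bias])
root = ToyModule(params={'scale': 2}, children={'head': head})
print(list(toy_get_state_dict(root)))
# ['scale', 'head.weight', 'head.bias']  (tmp removed, replacement lost)

root_with_hook = ToyModule(params={'scale': 2, 'bias': 0},
                           hooks=[replace_without_bias])
print(list(toy_get_state_dict(root_with_hook)))  # ['scale']
```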

Callers (3)

save_checkpoint · function · 0.90
main · function · 0.85
save_checkpoint · function · 0.85

Calls (2)

_save_to_state_dict · function · 0.85
hook · function · 0.85

Tested by

no test coverage detected