hub / github.com/microsoft/Cream / get_state_dict

Function get_state_dict

EfficientViT/downstream/mmcv_custom/checkpoint.py:394–435  ·  view source on GitHub ↗

Returns a dictionary containing the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are the corresponding parameter and buffer names. This method is modified from :meth:`torch.nn.Module.state_dict` to recursively check for parallel modules in case the model has a complicated structure, e.g., nn.Module(nn.Module(DDP)).

(module, destination=None, prefix='', keep_vars=False)

Source from the content-addressed store, hash-verified

def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict): Returned dict for the state of the
            module.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module.
    """
    # recursively check parallel module in case that the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP))
    if is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    destination._metadata[prefix[:-1]] = local_metadata = dict(
        version=module._version)
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination
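The unwrap-then-recurse pattern described in the docstring can be illustrated without PyTorch. The following is a minimal sketch, not the repository's code: `ToyModule`, `ToyWrapper`, and `collect_state` are hypothetical stand-ins for `nn.Module`, a parallel wrapper such as DDP, and `get_state_dict`. It omits metadata, buffers, and state-dict hooks, and only shows how unwrapping at every level of the recursion keeps a nested wrapper from adding a `module.` segment to the keys.

```python
from collections import OrderedDict


class ToyModule:
    """Hypothetical stand-in for nn.Module: named values plus named children."""

    def __init__(self, **params):
        self.params = dict(params)
        self.children = OrderedDict()

    def add(self, name, child):
        # Register a child under a name and return self to allow chaining.
        self.children[name] = child
        return self


class ToyWrapper(ToyModule):
    """Hypothetical stand-in for a parallel wrapper such as DDP."""

    def __init__(self, module):
        super().__init__()
        self.module = module


def collect_state(module, destination=None, prefix=''):
    # Unwrap before doing anything else, so the wrapper itself
    # contributes neither entries nor a key segment.
    if isinstance(module, ToyWrapper):
        module = module.module
    if destination is None:
        destination = OrderedDict()
    # Record this module's own values under the current prefix.
    for name, value in module.params.items():
        destination[prefix + name] = value
    # Recurse into children; each call repeats the unwrap check,
    # which is what handles wrappers buried inside the model.
    for name, child in module.children.items():
        if child is not None:
            collect_state(child, destination, prefix + name + '.')
    return destination


# A wrapper nested one level down, analogous to nn.Module(DDP(...)).
backbone = ToyModule(weight=1.0).add('head', ToyModule(bias=0.5))
model = ToyModule().add('backbone', ToyWrapper(backbone))
state = collect_state(model)
print(list(state))  # ['backbone.weight', 'backbone.head.bias']
```

In this sketch the keys match what an unwrapped model would produce, so the result could be loaded into a model that has no wrapper at that position.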

Callers 6

save_checkpoint · Function · 0.90
main · Function · 0.85
save_checkpoint · Function · 0.85
main · Function · 0.85
main · Function · 0.85
train_one_epoch · Function · 0.85

Calls 1

_save_to_state_dict · Function · 0.85

Tested by

no test coverage detected