Returns a dictionary containing a whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. This method is modified from :meth:`torch.nn.Module.state_dict` to recursively check parallel modules in case the model has a complicated structure, e.g., nn.Module(nn.Module(DDP)).
(module, destination=None, prefix='', keep_vars=False)
| 392 | |
| 393 | |
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing a whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are corresponding parameter and buffer names.

    This method is modified from :meth:`torch.nn.Module.state_dict` to
    recursively check parallel module in case that the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict, optional): Dict that receives the state of
            the module. When None, a new ``OrderedDict`` carrying a
            ``_metadata`` attribute is created. A caller-supplied mapping
            without ``_metadata`` is also accepted; per-module version
            metadata is then simply not recorded.
        prefix (str): Prefix of the key.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing a whole state of the module. This is
        ``destination`` itself unless a state-dict hook of ``module`` returns
        a replacement object.
    """
    # Strip every wrapper layer (e.g. DDP(DataParallel(model))), not just the
    # outermost one, so wrapper-introduced ``module.`` segments never end up
    # in the saved keys. Wrapped children are unwrapped by the recursive
    # calls below.
    while is_module_wrapper(module):
        module = module.module

    # below is the same as torch.nn.Module.state_dict()
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    local_metadata = dict(version=module._version)
    # A caller-supplied destination may not carry ``_metadata``; writing to it
    # unconditionally would raise AttributeError. Current
    # torch.nn.Module.state_dict applies the same guard.
    if hasattr(destination, '_metadata'):
        destination._metadata[prefix[:-1]] = local_metadata
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    # Hooks may post-process the state or return a replacement mapping.
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination
| 436 | |
| 437 | |
| 438 | def save_checkpoint(model, filename, optimizer=None, meta=None): |
no test coverage detected