Returns a dictionary containing the whole state of the module. Both parameters and persistent buffers (e.g. running averages) are included. Keys are the corresponding parameter and buffer names. This function is modified from :meth:`torch.nn.Module.state_dict` to recursively check for parallel module wrappers in case the model has a complicated structure, e.g., nn.Module(nn.Module(DDP)).
(module, destination=None, prefix='', keep_vars=False)
| 519 | |
| 520 | |
def get_state_dict(module, destination=None, prefix='', keep_vars=False):
    """Returns a dictionary containing the whole state of the module.

    Both parameters and persistent buffers (e.g. running averages) are
    included. Keys are the corresponding parameter and buffer names.

    This function is modified from :meth:`torch.nn.Module.state_dict` to
    recursively unwrap parallel modules in case the model has a complicated
    structure, e.g., nn.Module(nn.Module(DDP)).

    Args:
        module (nn.Module): The module to generate state_dict.
        destination (OrderedDict, optional): Dict that the state is written
            into and returned. If None, a new ``OrderedDict`` with an empty
            ``_metadata`` attribute is created. A caller-supplied dict
            without a ``_metadata`` attribute is accepted; version metadata
            is then simply not recorded. Default: None.
        prefix (str): Prefix prepended to every key. Non-empty prefixes are
            expected to end with ``'.'``. Default: ''.
        keep_vars (bool): Whether to keep the variable property of the
            parameters. Default: False.

    Returns:
        dict: A dictionary containing the whole state of the module. This
        is ``destination`` unless a state-dict hook returned a replacement.
    """
    # Recursively check parallel modules in case the model has a
    # complicated structure, e.g., nn.Module(nn.Module(DDP)). Only one
    # wrapper layer is stripped here; wrappers deeper in the tree are
    # stripped when the recursion below reaches them.
    if is_module_wrapper(module):
        module = module.module

    # Below is the same as torch.nn.Module.state_dict().
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
    local_metadata = dict(version=module._version)
    # A caller-supplied destination (e.g. a plain dict) may not carry a
    # ``_metadata`` attribute; mirror torch and skip recording metadata
    # instead of raising AttributeError.
    if hasattr(destination, '_metadata'):
        # prefix[:-1] drops the trailing '.' so the metadata key is the
        # dotted module path ('' for the root module).
        destination._metadata[prefix[:-1]] = local_metadata
    _save_to_state_dict(module, destination, prefix, keep_vars)
    for name, child in module._modules.items():
        if child is not None:
            get_state_dict(
                child, destination, prefix + name + '.', keep_vars=keep_vars)
    # State-dict hooks may post-process the dict; a non-None return value
    # replaces the destination that is handed back to the caller.
    for hook in module._state_dict_hooks.values():
        hook_result = hook(module, destination, prefix, local_metadata)
        if hook_result is not None:
            destination = hook_result
    return destination
| 563 | |
| 564 | |
| 565 | def save_checkpoint(model, filename, optimizer=None, meta=None): |
no test coverage detected