(module, prefix='')
| 69 | |
| 70 | # use _load_from_state_dict to enable checkpoint version control |
def load(module, prefix=''):
    """Load matching entries into ``module``, then recurse into its children.

    Goes through ``module._load_from_state_dict`` rather than a plain
    parameter copy so that per-module checkpoint version handling is applied.

    ``state_dict``, ``metadata``, ``all_missing_keys``, ``unexpected_keys``
    and ``err_msg`` are not parameters; they are read from the enclosing
    scope.

    Args:
        module: Module to load into. If ``is_module_wrapper`` reports it as
            a wrapper, its ``.module`` attribute is used instead.
        prefix: Dotted key prefix of ``module``. Empty for the root;
            recursive calls pass ``prefix + name + '.'``.
    """
    # The wrapper check runs on every recursive call, so a wrapper buried
    # inside plain modules, e.g. nn.Module(nn.Module(DDP)), is unwrapped too.
    if is_module_wrapper(module):
        module = module.module

    # prefix[:-1] drops the trailing '.' appended by the recursive call;
    # for the root prefix '' it is still ''.
    if metadata is None:
        local_metadata = {}
    else:
        local_metadata = metadata.get(prefix[:-1], {})

    # NOTE(review): the literal True sits in the fourth positional slot,
    # presumably ``strict`` of torch's Module._load_from_state_dict --
    # confirm against the torch version in use.
    module._load_from_state_dict(
        state_dict, prefix, local_metadata, True,
        all_missing_keys, unexpected_keys, err_msg)

    for child_name, submodule in module._modules.items():
        if submodule is None:
            # Empty child slot: nothing to load.
            continue
        load(submodule, prefix + child_name + '.')
| 84 | |
| 85 | load(module) |
| 86 | load = None # break load->load reference cycle |
no test coverage detected