(module, prefix='')
| 27 | state_dict._metadata = metadata |
| 28 | |
def load(module, prefix=''):
    """Recursively copy entries of the enclosing ``state_dict`` into ``module``.

    Args:
        module: Object exposing ``_load_from_state_dict`` and a ``_modules``
            mapping of child name to child object (entries may be ``None``).
        prefix: Key prefix for ``module``. Each child is visited with
            ``prefix + name + '.'``.

    Reads ``metadata`` and ``state_dict`` from the enclosing scope, and passes
    the enclosing ``missing_keys``, ``unexpected_keys`` and ``error_msgs``
    through to ``_load_from_state_dict``.
    """
    if metadata is None:
        module_metadata = {}
    else:
        # Metadata is looked up under the prefix minus its last character.
        module_metadata = metadata.get(prefix[:-1], {})

    # NOTE(review): the positional ``True`` is presumably the ``strict`` flag
    # of torch's ``_load_from_state_dict`` -- confirm against the torch version.
    module._load_from_state_dict(
        state_dict,
        prefix,
        module_metadata,
        True,
        missing_keys,
        unexpected_keys,
        error_msgs,
    )

    for child_name, submodule in module._modules.items():
        if submodule is None:
            # Empty child slots are skipped.
            continue
        load(submodule, prefix + child_name + '.')
| 37 | |
| 38 | load(model, prefix='' if hasattr(model, 'bert') else 'bert.') |
| 39 |