(module, local_state_dict, prefix="")
| 31 | |
| 32 | ... |
def load(module, local_state_dict, prefix=""):
    """Load ``local_state_dict`` into ``module``, then recurse into its children.

    This is a closure helper: it reads ``metadata``, ``assign``,
    ``missing_keys``, ``unexpected_keys`` and ``error_msgs`` from the
    enclosing scope. Those names are defined outside this view. The three
    lists are passed straight through to ``module._load_from_state_dict``
    (presumably so it can record problems there — TODO confirm).

    Args:
        module: object exposing ``_load_from_state_dict`` and a
            ``_modules`` mapping of child name -> child (or ``None``).
        local_state_dict: entries handed to ``module``. For children, only
            keys starting with the child's prefix are kept.
        prefix: dotted path of ``module`` including a trailing ``"."``.
            It is empty for the root call. ``prefix[:-1]`` (the path without
            that dot) is the key used to look up per-module metadata.
    """
    if metadata is None:
        local_metadata = {}
    else:
        # Copy: ``metadata.get`` returns the dict stored in the shared
        # metadata mapping. Writing the assign flag into it directly would
        # leak into later loads of the same state dict, including ones made
        # with a falsy ``assign``, since the flag is only ever set here.
        local_metadata = dict(metadata.get(prefix[:-1], {}))
    if assign:
        local_metadata["assign_to_params_buffers"] = assign
    module._load_from_state_dict(
        local_state_dict,
        prefix,
        local_metadata,
        True,  # positional flag, presumably ``strict`` — TODO confirm
        missing_keys,
        unexpected_keys,
        error_msgs,
    )
    for name, child in module._modules.items():
        # ``None`` entries in ``_modules`` have nothing to load.
        if child is not None:
            child_prefix = prefix + name + "."
            # Narrow the dict so each child only sees its own subtree's keys.
            child_state_dict = {
                k: v
                for k, v in local_state_dict.items()
                if k.startswith(child_prefix)
            }
            load(child, child_state_dict, child_prefix)  # noqa: F821
| 55 | |
| 56 | def __setattr__(self, name: str, value: Union[Tensor, "Module"]) -> None: |
| 57 | ... |
no test coverage detected