Copy the master parameter data back into the model parameters.
(param_groups_and_shapes, master_params)
| 63 | |
| 64 | |
def master_params_to_model_params(param_groups_and_shapes, master_params):
    """
    Copy the master parameter data back into the model parameters.

    :param param_groups_and_shapes: an iterable of ``(param_group, shape)``
        pairs. Each ``param_group`` is an iterable of ``(name, param)`` pairs;
        the ``shape`` element is not used by this function.
    :param master_params: an iterable of master tensors, paired positionally
        with ``param_groups_and_shapes``. Each one is flattened with
        ``.view(-1)`` and split back into per-parameter pieces by
        ``unflatten_master_params``.
    """
    # zip() stops at the shorter input, so a length mismatch between
    # master_params and param_groups_and_shapes is not reported.
    for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
        # param_group is passed to unflatten_master_params() (which presumably
        # iterates it) and is then iterated again by zip() below. Materialize
        # it so a one-shot iterator/generator is not exhausted by the first
        # pass, which would leave zip() empty and silently copy nothing.
        param_group = list(param_group)
        unflat_master_params = unflatten_master_params(
            param_group, master_param.view(-1)
        )
        for (_, param), unflat_master_param in zip(param_group, unflat_master_params):
            # detach() yields a tensor sharing param's storage without autograd
            # tracking, so copy_ writes into the model parameter in place.
            param.detach().copy_(unflat_master_param)
| 76 | |
| 77 | |
| 78 | def unflatten_master_params(param_group, master_param): |
no test coverage detected