r""" Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without affecting the original optimization process. Store the parameters before the `copy_to()` method. After validation (or model saving), use this to restore the f
(self, parameters: Iterable[torch.nn.Parameter])
| 830 | self.temp_stored_params = [param.detach().cpu().clone() for param in parameters] |
| 831 | |
def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
    r"""
    Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
    without affecting the original optimization process. Store the parameters before the `copy_to()` method.
    After validation (or model saving), use this to restore the former parameters.

    Args:
        parameters: Iterable of `torch.nn.Parameter`; the parameters to be updated in place with the stored
            values. It must yield the same number of parameters, in the same order, as was passed to
            `store()`. A one-shot iterator such as `model.parameters()` is accepted.

    Raises:
        RuntimeError: If no parameters are currently stored (e.g. `restore()` was already called since the
            last `store()`).
        ValueError: If the number of given parameters differs from the number of stored ones. Nothing is
            modified in that case, and the stored copies are kept so the call can be retried.
    """
    stored = self.temp_stored_params
    if stored is None:
        raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")

    # Materialize once: `parameters` may be a generator, and the length check below must not consume it.
    targets = list(parameters)
    # Validate before touching any tensor. A silent `zip` truncation would restore only part of the model
    # and then discard the remaining stored copies below.
    if len(targets) != len(stored):
        raise ValueError(
            f"Expected {len(stored)} parameters to restore, got {len(targets)}; "
            "pass the same parameters that were given to `store()`."
        )

    if self.foreach:
        torch._foreach_copy_([param.data for param in targets], [saved.data for saved in stored])
    else:
        for saved, param in zip(stored, targets):
            param.data.copy_(saved.data)

    # Better memory-wise: drop the stored copies once they have been written back.
    self.temp_stored_params = None
| 857 | def load_state_dict(self, state_dict: dict) -> None: |
| 858 | r""" |