Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during checkpointing to save the ema state dict.
(self)
| 802 | ] |
| 803 | |
| 804 | def state_dict(self) -> dict: |
| 805 | r""" |
| 806 | Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during |
| 807 | checkpointing to save the ema state dict. |
| 808 | """ |
| 809 | # Following PyTorch conventions, references to tensors are returned: |
| 810 | # "returns a reference to the state and not its copy!" - |
| 811 | # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict |
| 812 | return { |
| 813 | "decay": self.decay, |
| 814 | "min_decay": self.min_decay, |
| 815 | "optimization_step": self.optimization_step, |
| 816 | "update_after_step": self.update_after_step, |
| 817 | "use_ema_warmup": self.use_ema_warmup, |
| 818 | "inv_gamma": self.inv_gamma, |
| 819 | "power": self.power, |
| 820 | "shadow_params": self.shadow_params, |
| 821 | } |
| 822 | |
| 823 | def store(self, parameters: Iterable[torch.nn.Parameter]) -> None: |
| 824 | r""" |
no outgoing calls