r""" Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to restore the ema state dict. Args: state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`.
(self, state_dict: dict)
| 855 | self.temp_stored_params = None |
| 856 | |
def load_state_dict(self, state_dict: dict) -> None:
    r"""
    Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to restore the
    ema state dict.

    Every key is optional: a missing key keeps the attribute's current value. All values are validated before any
    attribute is assigned, so a ``ValueError`` leaves this object unchanged.

    Args:
        state_dict (dict): EMA state. Should be an object returned
            from a call to :meth:`state_dict`.

    Raises:
        ValueError: If any value in ``state_dict`` has an invalid type or is out of range.
    """
    # deepcopy, to be consistent with module API
    state_dict = copy.deepcopy(state_dict)

    # --- Validate everything first; nothing on ``self`` is touched until all checks pass. ---
    # The checks run in the same order as before, so the first reported error is unchanged.
    decay = state_dict.get("decay", self.decay)
    # Chained comparison also rejects NaN, which fails every ordered comparison.
    if not 0.0 <= decay <= 1.0:
        raise ValueError("Decay must be between 0 and 1")

    min_decay = state_dict.get("min_decay", self.min_decay)
    # Accept ints as well, consistent with inv_gamma / power below.
    if not isinstance(min_decay, (float, int)):
        raise ValueError("Invalid min_decay")

    optimization_step = state_dict.get("optimization_step", self.optimization_step)
    if not isinstance(optimization_step, int):
        raise ValueError("Invalid optimization_step")

    update_after_step = state_dict.get("update_after_step", self.update_after_step)
    if not isinstance(update_after_step, int):
        raise ValueError("Invalid update_after_step")

    use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
    if not isinstance(use_ema_warmup, bool):
        raise ValueError("Invalid use_ema_warmup")

    inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
    if not isinstance(inv_gamma, (float, int)):
        raise ValueError("Invalid inv_gamma")

    power = state_dict.get("power", self.power)
    if not isinstance(power, (float, int)):
        raise ValueError("Invalid power")

    shadow_params = state_dict.get("shadow_params", None)
    if shadow_params is not None:
        if not isinstance(shadow_params, list):
            raise ValueError("shadow_params must be a list")
        if not all(isinstance(p, torch.Tensor) for p in shadow_params):
            raise ValueError("shadow_params must all be Tensors")

    # --- All values are valid: commit them. ---
    self.decay = decay
    self.min_decay = min_decay
    self.optimization_step = optimization_step
    self.update_after_step = update_after_step
    self.use_ema_warmup = use_ema_warmup
    self.inv_gamma = inv_gamma
    self.power = power
    # shadow_params are only replaced when the state dict actually provides them.
    if shadow_params is not None:
        self.shadow_params = shadow_params