MCPcopy
hub / github.com/huggingface/diffusers / load_state_dict

Method load_state_dict

src/diffusers/training_utils.py:857–903  ·  view source on GitHub ↗

r""" Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to restore the ema state dict. Args: state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`.

(self, state_dict: dict)

Source from the content-addressed store, hash-verified

855 self.temp_stored_params = None
856
def load_state_dict(self, state_dict: dict) -> None:
    r"""
    Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to restore the
    ema state dict.

    Every value is validated before any attribute is assigned, so a rejected `state_dict` leaves the current EMA
    state unchanged. Keys absent from `state_dict` keep their current values.

    Args:
        state_dict (dict): EMA state. Should be an object returned
            from a call to :meth:`state_dict`.

    Raises:
        ValueError: If `decay` lies outside [0, 1] (NaN included), if another hyper-parameter has the wrong type,
            or if `shadow_params` is present but is not a list of `torch.Tensor`.
    """
    # deepcopy, to be consistent with module API: the loaded state must not alias the caller's objects.
    state_dict = copy.deepcopy(state_dict)

    # Collect validated values first and assign them only at the end, so a failure part-way through cannot
    # leave the EMA half-updated with an invalid hyper-parameter.
    updates = {}

    decay = state_dict.get("decay", self.decay)
    # Negated range test so that NaN (for which every comparison is False) is rejected as well.
    if not (0.0 <= decay <= 1.0):
        raise ValueError("Decay must be between 0 and 1")
    updates["decay"] = decay

    # (key, accepted types), checked in the same order as before so the first reported error is unchanged.
    # `min_decay` accepts ints like `inv_gamma`/`power`, so a state dict saved from an EMA built with e.g.
    # `min_decay=0` can be loaded back.
    typed_fields = (
        ("min_decay", (float, int)),
        ("optimization_step", int),
        ("update_after_step", int),
        ("use_ema_warmup", bool),
        ("inv_gamma", (float, int)),
        ("power", (float, int)),
    )
    for key, expected_types in typed_fields:
        value = state_dict.get(key, getattr(self, key))
        if not isinstance(value, expected_types):
            raise ValueError(f"Invalid {key}")
        updates[key] = value

    shadow_params = state_dict.get("shadow_params", None)
    if shadow_params is not None:
        if not isinstance(shadow_params, list):
            raise ValueError("shadow_params must be a list")
        if not all(isinstance(p, torch.Tensor) for p in shadow_params):
            raise ValueError("shadow_params must all be Tensors")
        updates["shadow_params"] = shadow_params

    # Everything validated: commit all values together.
    for key, value in updates.items():
        setattr(self, key, value)

Calls 1

getMethod · 0.45