MCPcopy
hub / github.com/showlab/Show-o / load_state_dict

Method load_state_dict

models/training_utils.py:250–295  ·  view source on GitHub ↗

Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to restore the EMA state dict. Args: state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`.

(self, state_dict: dict)

Source from the content-addressed store, hash-verified

248 self.temp_stored_params = None
249
def load_state_dict(self, state_dict: dict) -> None:
    r"""
    Loads the ExponentialMovingAverage state.

    This method is used by accelerate during checkpointing to restore the EMA
    state dict. Every value is validated before any attribute is modified, so
    a ``ValueError`` leaves the current EMA state untouched.

    Args:
        state_dict (dict): EMA state. Should be an object returned from a call
            to :meth:`state_dict`. Missing keys keep their current value; a
            missing or ``None`` ``"shadow_params"`` entry keeps the current
            shadow parameters.

    Raises:
        ValueError: If any loaded value has an invalid type or range.
    """
    # deepcopy, to be consistent with module API (the caller's dict and
    # tensors are never aliased by this object).
    state_dict = copy.deepcopy(state_dict)

    # --- Validate everything first; nothing on ``self`` is touched yet. ---
    decay = state_dict.get("decay", self.decay)
    if decay < 0.0 or decay > 1.0:
        raise ValueError("Decay must be between 0 and 1")

    min_decay = state_dict.get("min_decay", self.min_decay)
    # Accept ints (e.g. ``0``) like inv_gamma/power do, so an integer
    # min_decay survives a save/load round trip.
    if not isinstance(min_decay, (float, int)):
        raise ValueError("Invalid min_decay")

    optimization_step = state_dict.get("optimization_step", self.optimization_step)
    if not isinstance(optimization_step, int):
        raise ValueError("Invalid optimization_step")

    update_after_step = state_dict.get("update_after_step", self.update_after_step)
    if not isinstance(update_after_step, int):
        raise ValueError("Invalid update_after_step")

    use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
    if not isinstance(use_ema_warmup, bool):
        raise ValueError("Invalid use_ema_warmup")

    inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
    if not isinstance(inv_gamma, (float, int)):
        raise ValueError("Invalid inv_gamma")

    power = state_dict.get("power", self.power)
    if not isinstance(power, (float, int)):
        raise ValueError("Invalid power")

    shadow_params = state_dict.get("shadow_params", None)
    if shadow_params is not None:
        if not isinstance(shadow_params, list):
            raise ValueError("shadow_params must be a list")
        if not all(isinstance(p, torch.Tensor) for p in shadow_params):
            raise ValueError("shadow_params must all be Tensors")

    # --- Commit: reached only when every value passed validation. ---
    self.decay = decay
    self.min_decay = min_decay
    self.optimization_step = optimization_step
    self.update_after_step = update_after_step
    self.use_ema_warmup = use_ema_warmup
    self.inv_gamma = inv_gamma
    self.power = power
    if shadow_params is not None:
        self.shadow_params = shadow_params
296
297
298# calculates entropy over each pixel distribution

Callers 13

inference_mmu.pyFile · 0.80
mainFunction · 0.80
mainFunction · 0.80
inference_t2i.pyFile · 0.80
mainFunction · 0.80
_video_vaeFunction · 0.80
inference_dpg.pyFile · 0.80
mainFunction · 0.80
mainFunction · 0.80

Calls

no outgoing calls

Tested by

no test coverage detected