r""" Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to restore the ema state dict. Args: state_dict (dict): EMA state. Should be an object returned from a call to :meth:`state_dict`.
(self, state_dict: dict)
| 248 | self.temp_stored_params = None |
| 249 | |
def load_state_dict(self, state_dict: dict) -> None:
    r"""
    Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to restore
    the ema state dict.

    Every entry is optional: a key missing from `state_dict` leaves the corresponding attribute at its current
    value. All values are validated before any attribute is written, so a rejected `state_dict` leaves this
    object unchanged.

    Args:
        state_dict (dict): EMA state. Should be an object returned
            from a call to :meth:`state_dict`.

    Raises:
        ValueError: If `decay` is outside `[0, 1]` (NaN included), if any scalar entry has an unexpected type,
            or if `shadow_params` is not a list of tensors.
    """
    # deepcopy, to be consistent with module API
    state_dict = copy.deepcopy(state_dict)

    # Validate into locals first; nothing on `self` is modified until every check has passed.
    decay = state_dict.get("decay", self.decay)
    # Written as a chained comparison so that NaN (for which every comparison is False) is rejected too.
    if not 0.0 <= decay <= 1.0:
        raise ValueError("Decay must be between 0 and 1")

    min_decay = state_dict.get("min_decay", self.min_decay)
    # Accept ints as well as floats (e.g. `min_decay=0`), matching the checks on `inv_gamma` and `power`.
    if not isinstance(min_decay, (float, int)):
        raise ValueError("Invalid min_decay")

    optimization_step = state_dict.get("optimization_step", self.optimization_step)
    if not isinstance(optimization_step, int):
        raise ValueError("Invalid optimization_step")

    update_after_step = state_dict.get("update_after_step", self.update_after_step)
    if not isinstance(update_after_step, int):
        raise ValueError("Invalid update_after_step")

    use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
    if not isinstance(use_ema_warmup, bool):
        raise ValueError("Invalid use_ema_warmup")

    inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
    if not isinstance(inv_gamma, (float, int)):
        raise ValueError("Invalid inv_gamma")

    power = state_dict.get("power", self.power)
    if not isinstance(power, (float, int)):
        raise ValueError("Invalid power")

    shadow_params = state_dict.get("shadow_params", None)
    if shadow_params is not None:
        if not isinstance(shadow_params, list):
            raise ValueError("shadow_params must be a list")
        if not all(isinstance(p, torch.Tensor) for p in shadow_params):
            raise ValueError("shadow_params must all be Tensors")

    # Everything validated: commit the new state in one go.
    self.decay = decay
    self.min_decay = min_decay
    self.optimization_step = optimization_step
    self.update_after_step = update_after_step
    self.use_ema_warmup = use_ema_warmup
    self.inv_gamma = inv_gamma
    self.power = power
    if shadow_params is not None:
        self.shadow_params = shadow_params
| 296 | |
| 297 | |
| 298 | # calculates entropy over each pixel distribution |
no outgoing calls
no test coverage detected