github.com/huggingface/diffusers

Class EMAModel

src/diffusers/training_utils.py:571–903

Exponential Moving Average of model weights.


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of model weights
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: float | int = 1.0,
        power: float | int = 2 / 3,
        foreach: bool = False,
        model_cls: Any | None = None,
        model_config: dict[str, Any] | None = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
            device (str | torch.device | None): The device to store the EMA weights on. If None, the EMA
                weights will be stored on CPU.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4K steps).
        """

        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
            use_ema_warmup = True

        if kwargs.get("max_value", None) is not None:
            deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
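@crowsonkb's warmup notes in the docstring can be checked numerically. The sketch below assumes the warmup schedule has the form decay(step) = 1 − (1 + step / inv_gamma)^(−power), clamped to [min_decay, decay]. That form is inferred from the numbers quoted in the notes (with inv_gamma = power = 1 it reduces to step / (step + 1), i.e. a simple running average), and the helpers `warmup_decay` and `steps_to_reach` are hypothetical, not part of the diffusers API.

```python
def warmup_decay(
    step: int,
    decay: float = 0.9999,
    min_decay: float = 0.0,
    inv_gamma: float = 1.0,
    power: float = 2 / 3,
) -> float:
    """Assumed warmup schedule: 1 - (1 + step / inv_gamma) ** -power, clamped to [min_decay, decay]."""
    if step <= 0:
        return 0.0  # assumption: no averaging before the first counted step
    value = 1 - (1 + step / inv_gamma) ** -power
    return max(min_decay, min(value, decay))


def steps_to_reach(target: float, inv_gamma: float = 1.0, power: float = 2 / 3) -> float:
    """Invert the unclamped schedule: the step at which the decay equals `target`."""
    return inv_gamma * ((1 - target) ** (-1 / power) - 1)


# Reproduce the step counts quoted in the notes (31.6K, 1M, 10K, 215.4K).
for power in (2 / 3, 3 / 4):
    for target in (0.999, 0.9999):
        steps = steps_to_reach(target, power=power)
        print(f"power={power:.3f} target={target}: about {steps:,.0f} steps")
```

Under this assumed form, a smaller `power` makes the decay ramp up more slowly, which is why power=2/3 suits runs of a million or more steps and power=3/4 suits shorter ones.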

Callers (15)

get_models · method · 0.90 (×2)
test_ema_training · method · 0.90
main · function · 0.90 (×8)

Calls

no outgoing calls

Tested by (4)

get_models · method · 0.72 (×2)
test_ema_training · method · 0.72
