Exponential Moving Average of model weights
| 569 | |
| 570 | # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 |
| 571 | class EMAModel: |
| 572 | """ |
| 573 | Exponential Moving Average of models weights |
| 574 | """ |
| 575 | |
| 576 | def __init__( |
| 577 | self, |
| 578 | parameters: Iterable[torch.nn.Parameter], |
| 579 | decay: float = 0.9999, |
| 580 | min_decay: float = 0.0, |
| 581 | update_after_step: int = 0, |
| 582 | use_ema_warmup: bool = False, |
| 583 | inv_gamma: float | int = 1.0, |
| 584 | power: float | int = 2 / 3, |
| 585 | foreach: bool = False, |
| 586 | model_cls: Any | None = None, |
| 587 | model_config: dict[str, Any] | None = None, |
| 588 | **kwargs, |
| 589 | ): |
| 590 | """ |
| 591 | Args: |
| 592 | parameters (Iterable[torch.nn.Parameter]): The parameters to track. |
| 593 | decay (float): The decay factor for the exponential moving average. |
| 594 | min_decay (float): The minimum decay factor for the exponential moving average. |
| 595 | update_after_step (int): The number of steps to wait before starting to update the EMA weights. |
| 596 | use_ema_warmup (bool): Whether to use EMA warmup. |
| 597 | inv_gamma (float): |
| 598 | Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True. |
| 599 | power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True. |
| 600 | foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster. |
| 601 | device (str | torch.device | None): The device to store the EMA weights on. If None, the EMA |
| 602 | weights will be stored on CPU. |
| 603 | |
| 604 | @crowsonkb's notes on EMA Warmup: |
| 605 | If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan |
| 606 | to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps), |
| 607 | gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 |
| 608 | at 215.4k steps). |
| 609 | """ |
| 610 | |
| 611 | if isinstance(parameters, torch.nn.Module): |
| 612 | deprecation_message = ( |
| 613 | "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. " |
| 614 | "Please pass the parameters of the module instead." |
| 615 | ) |
| 616 | deprecate( |
| 617 | "passing a `torch.nn.Module` to `ExponentialMovingAverage`", |
| 618 | "1.0.0", |
| 619 | deprecation_message, |
| 620 | standard_warn=False, |
| 621 | ) |
| 622 | parameters = parameters.parameters() |
| 623 | |
| 624 | # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility |
| 625 | use_ema_warmup = True |
| 626 | |
| 627 | if kwargs.get("max_value", None) is not None: |
| 628 | deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead." |
no outgoing calls
searching dependent graphs…