| 703 | |
@torch.no_grad()
def step(self, parameters: Iterable[torch.nn.Parameter]):
    """Advance the moving average by one optimization step.

    Increments `self.optimization_step`, obtains the decay for that step from
    `self.get_decay`, stores it in `self.cur_decay_value`, and then updates
    every tensor in `self.shadow_params` in place from the matching entry of
    `parameters` (paired positionally with `zip`):

    * parameters with `requires_grad=True`:
      `shadow -= (1 - decay) * (shadow - param)`
    * parameters with `requires_grad=False`: `shadow` is overwritten with a
      copy of `param`.

    Args:
        parameters: The current parameters, in the same order as
            `self.shadow_params`. Passing a `torch.nn.Module` is still
            accepted but deprecated; its `.parameters()` are used instead.
    """
    if isinstance(parameters, torch.nn.Module):
        deprecation_message = (
            "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
            "Please pass the parameters of the module instead."
        )
        deprecate(
            "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
            "1.0.0",
            deprecation_message,
            standard_warn=False,
        )
        parameters = parameters.parameters()

    # Materialize once: the iterable is traversed several times below and may
    # be a one-shot generator.
    parameters = list(parameters)

    self.optimization_step += 1

    # Compute the decay factor for the exponential moving average.
    decay = self.get_decay(self.optimization_step)
    self.cur_decay_value = decay
    one_minus_decay = 1 - decay

    context_manager = contextlib.nullcontext()

    # The ZeRO-3 check does not depend on the individual parameter, so it is
    # evaluated a single time instead of once per parameter.
    zero3_enabled = (
        is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled()
    )

    if self.foreach:
        if zero3_enabled:
            context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)

        with context_manager:
            params_grad = [param for param in parameters if param.requires_grad]
            s_params_grad = [
                s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
            ]

            # At least one parameter is frozen: mirror those values verbatim.
            if len(params_grad) < len(parameters):
                torch._foreach_copy_(
                    [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
                    [param for param in parameters if not param.requires_grad],
                    non_blocking=True,
                )

            # Skip the fused update when nothing is trainable: the call would
            # be a no-op at best, and some PyTorch releases appear to reject
            # empty tensor lists in `_foreach_*` ops (TODO confirm versions).
            if params_grad:
                torch._foreach_sub_(
                    s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
                )

    else:
        for s_param, param in zip(self.shadow_params, parameters):
            if zero3_enabled:
                context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

            with context_manager:
                if param.requires_grad:
                    s_param.sub_(one_minus_decay * (s_param - param))
                else:
                    s_param.copy_(param)
| 762 | def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: |