MCPcopy Index your code
hub / github.com/huggingface/diffusers / step

Method step

src/diffusers/training_utils.py:705–760  ·  view source on GitHub ↗
(self, parameters: Iterable[torch.nn.Parameter])

Source from the content-addressed store, hash-verified

703
@torch.no_grad()
def step(self, parameters: Iterable[torch.nn.Parameter]):
    """Advance the exponential moving average by one optimization step.

    Increments `self.optimization_step`, computes the decay for that step via
    `self.get_decay`, records it in `self.cur_decay_value`, and then updates
    `self.shadow_params` in place:

    * parameters with `requires_grad=True` are blended into their shadow copy:
      `shadow -= (1 - decay) * (shadow - param)`
    * parameters with `requires_grad=False` are copied into their shadow copy verbatim.

    Args:
        parameters: The current model parameters. They are paired positionally (via `zip`)
            with `self.shadow_params`, so they must be supplied in the same order.
            Passing a `torch.nn.Module` is still accepted but deprecated; its
            `.parameters()` are used instead.
    """
    if isinstance(parameters, torch.nn.Module):
        deprecation_message = (
            "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
            "Please pass the parameters of the module instead."
        )
        deprecate(
            "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
            "1.0.0",
            deprecation_message,
            standard_warn=False,
        )
        parameters = parameters.parameters()

    # Materialize once: the argument may be a generator and is traversed several times below.
    parameters = list(parameters)

    self.optimization_step += 1

    # Compute the decay factor for the exponential moving average.
    decay = self.get_decay(self.optimization_step)
    self.cur_decay_value = decay
    one_minus_decay = 1 - decay

    # The ZeRO-3 check does not depend on the individual parameter, so evaluate it once
    # instead of once per parameter.
    zero3_enabled = (
        is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled()
    )

    if self.foreach:
        # NOTE(review): GatheredParameters presumably makes ZeRO-3 partitioned parameters
        # fully available inside the `with` block -- confirm against the DeepSpeed docs.
        if zero3_enabled:
            context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
        else:
            context_manager = contextlib.nullcontext()

        with context_manager:
            params_grad = [param for param in parameters if param.requires_grad]
            s_params_grad = [
                s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
            ]

            # Frozen parameters are mirrored as-is; skip the call when there are none.
            if len(params_grad) < len(parameters):
                torch._foreach_copy_(
                    [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
                    [param for param in parameters if not param.requires_grad],
                    non_blocking=True,
                )

            # Skip the EMA update when nothing is trainable: foreach ops can reject empty
            # tensor lists, and the update would be a no-op anyway.
            if params_grad:
                # shadow -= (1 - decay) * (shadow - param)
                torch._foreach_sub_(
                    s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
                )

    else:
        for s_param, param in zip(self.shadow_params, parameters):
            # Under ZeRO-3 each parameter gets its own gather context.
            if zero3_enabled:
                context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
            else:
                context_manager = contextlib.nullcontext()

            with context_manager:
                if param.requires_grad:
                    # shadow -= (1 - decay) * (shadow - param)
                    s_param.sub_(one_minus_decay * (s_param - param))
                else:
                    s_param.copy_(param)
761
762 def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:

Calls 4

get_decayMethod · 0.95
deprecateFunction · 0.85
parametersMethod · 0.80