MCP · Copy · Index your code
hub / github.com/modelscope/FunASR / _step_internal

Method _step_internal

funasr/models/data2vec/ema_module.py:93–126  ·  view source on GitHub ↗

One update of the EMA model based on new model weights

(self, new_model)

Source from the content-addressed store, hash-verified

91 return self.decay
92
93 def _step_internal(self, new_model):
94 """One update of the EMA model based on new model weights"""
95 decay = self.decay
96
97 ema_state_dict = {}
98 ema_params = self.fp32_params if self.ema_fp32 else self.model.state_dict()
99 for key, param in new_model.state_dict().items():
100 if isinstance(param, dict):
101 continue
102 try:
103 ema_param = ema_params[key]
104 except KeyError:
105 ema_param = param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
106
107 if param.shape != ema_param.shape:
108 raise ValueError(
109 "incompatible tensor shapes between model param and ema param"
110 + "{} vs. {}".format(param.shape, ema_param.shape)
111 )
112
113 if "version" in key:
114 # Do not decay a model.version pytorch param
115 continue
116
117 if key in self.skip_keys or (
118 "num_batches_tracked" in key and ema_param.dtype == torch.int64
119 ):
120 ema_param = param.to(dtype=ema_param.dtype).clone()
121 ema_params[key].copy_(ema_param)
122 else:
123 ema_param.mul_(decay)
124 ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
125 ema_state_dict[key] = ema_param
126 self.restore(ema_state_dict, build_fp32_params=False)
127
128 def step(self, new_model):
129 """Step.

Callers 1

step · Method · 0.95

Calls 2

restore · Method · 0.95
state_dict · Method · 0.45

Tested by

no test coverage detected