Update shadow weights with current model parameters.
(self, model)
| 82 | return self.decay |
| 83 | |
def update(self, model):
    """Blend the model's current parameters into the EMA shadow weights.

    The decay for this step is taken from ``self._get_decay()`` and cached on
    ``self._decay``. For every key in ``self.state_dict`` that is not listed in
    ``self.ema_black_list`` and is present in ``model.state_dict()``, the
    shadow entry is replaced by::

        decay * shadow + (1 - decay) * current.astype("float32")

    The new value is marked ``stop_gradient = True``. Other keys are left
    untouched. The step counter ``self.step`` is incremented exactly once per
    call, even if no entry was updated.

    Args:
        model: Object exposing ``state_dict()``. That method returns a mapping
            whose values support ``.astype("float32")`` and arithmetic with the
            shadow values. NOTE(review): presumably Paddle tensors, given the
            ``stop_gradient`` attribute; confirm against callers.
    """
    self._decay = decay = self._get_decay()
    params = model.state_dict()
    blend = 1 - decay  # weight given to the live parameter value
    for name, shadow in self.state_dict.items():
        # Skip black-listed keys and keys the model does not expose.
        if name in self.ema_black_list or name not in params:
            continue
        averaged = decay * shadow + blend * params[name].astype("float32")
        # The shadow copy is not meant to receive gradients.
        averaged.stop_gradient = True
        # Rebinding an existing key while iterating is safe: the key set is unchanged.
        self.state_dict[name] = averaged
    self.step += 1
| 95 | |
| 96 | def apply(self): |
| 97 | """Return bias-corrected EMA state dict for eval/save. |
no test coverage detected