(self)
| 74 | self.state_dict[k] = paddle.zeros_like(v).astype("float32") |
| 75 | |
def _get_decay(self):
    """Return the decay factor to use at the current step.

    The schedule is chosen by ``self.ema_decay_type``:

    * ``"threshold"``: ``min(self.decay, (1 + step) / (10 + step))``.
    * ``"exponential"``: ``self.decay * (1 - exp(-(step + 1) / gamma))``.
    * any other value (the "normal" mode): ``self.decay`` unchanged.
    """
    schedule = self.ema_decay_type

    if schedule == "threshold":
        # Step-dependent cap; the result never exceeds self.decay.
        step_cap = (1 + self.step) / (10 + self.step)
        return min(self.decay, step_cap)

    if schedule == "exponential":
        # Scale self.decay by a factor that depends on step and gamma.
        ramp = 1 - math.exp(-(self.step + 1) / self.gamma)
        return self.decay * ramp

    # normal
    return self.decay
| 83 | |
| 84 | def update(self, model): |
| 85 | """Update shadow weights with current model parameters.""" |