Compute the decay factor for the exponential moving average.
(self, optimization_step: int)
| 683 | model.save_pretrained(path) |
| 684 | |
def get_decay(self, optimization_step: int) -> float:
    """
    Return the EMA decay factor to use at `optimization_step`.

    The step is first shifted by `self.update_after_step + 1` and clamped at zero; as long as the shifted
    step is not positive, the decay is `0.0`. After that the decay follows one of two schedules:

    - `1 - (1 + step / self.inv_gamma) ** -self.power` when `self.use_ema_warmup` is truthy,
    - `(1 + step) / (10 + step)` otherwise.

    The scheduled value is then capped at `self.decay` and finally raised to at least `self.min_decay`.

    Args:
        optimization_step (`int`): Step counter the decay is computed for.

    Returns:
        `float`: The decay factor for this step.
    """
    effective_step = max(0, optimization_step - self.update_after_step - 1)
    if effective_step <= 0:
        return 0.0

    if self.use_ema_warmup:
        scheduled_decay = 1 - (1 + effective_step / self.inv_gamma) ** -self.power
    else:
        scheduled_decay = (1 + effective_step) / (10 + effective_step)

    # Cap at `decay` first, then apply the `min_decay` floor last, so the floor wins if the two bounds conflict.
    return max(min(scheduled_decay, self.decay), self.min_decay)
| 703 | |
| 704 | @torch.no_grad() |
| 705 | def step(self, parameters: Iterable[torch.nn.Parameter]): |