Return bias-corrected EMA state dict for eval/save. Does NOT modify internal state.
(self)
| 94 | self.step += 1 |
| 95 | |
def apply(self):
    """Return a bias-corrected copy of the EMA state dict for eval/save.

    Does NOT modify internal state: every returned value is a new tensor
    (never the object stored in ``self.state_dict``) with
    ``stop_gradient = True``. Callers may therefore mutate the result, or
    load it into a model, without touching the EMA buffers.

    Returns:
        dict: one entry per key of ``self.state_dict``.
            - Keys in ``self.ema_black_list`` are copied unchanged.
            - When ``self.ema_decay_type != "exponential"`` and at least one
              update has happened (``self.step != 0``), every other value is
              divided by ``1 - self._decay ** self.step`` (bias correction).
    """
    # At step 0 the divisor 1 - decay**0 is exactly 0, so correction is only
    # meaningful after the first update. The "exponential" decay type needs
    # no correction at all.
    needs_correction = self.step != 0 and self.ema_decay_type != "exponential"
    # Loop-invariant: compute the divisor once rather than once per key.
    correction = 1 - self._decay**self.step if needs_correction else None

    state = {}
    for k, v in self.state_dict.items():
        if needs_correction and k not in self.ema_black_list:
            # threshold / normal decay types need bias-correction; the
            # division already allocates a new tensor, so no clone is needed.
            out = v / correction
        else:
            # Copy so the caller never holds a reference to internal state
            # (previously black-listed entries were returned aliased).
            out = v.clone()
        out.stop_gradient = True
        state[k] = out
    return state
| 116 | def state_dict_for_save(self): |
| 117 | """Return serializable dict for checkpoint.""" |
no test coverage detected