| 4 | |
| 5 | class LitEma(nn.Module): |
def __init__(self, model, decay=0.9999, use_num_upates=True):
    """Snapshot ``model``'s trainable parameters into EMA shadow buffers.

    Args:
        model: module whose parameters with ``requires_grad=True`` are
            copied (detached) into buffers registered on this module.
        decay: EMA decay factor; must lie in the closed range [0, 1].
        use_num_upates: (sic -- the misspelling is part of the public
            signature and is kept for compatibility). If True the
            ``num_updates`` buffer starts at 0, otherwise at -1.
            NOTE(review): how -1 is interpreted is decided by code outside
            this method (presumably "do not track update count") -- confirm.

    Raises:
        ValueError: if ``decay`` is NaN or outside [0, 1], or if a
            parameter's dot-free buffer name collides with another
            parameter's or with the ``decay``/``num_updates`` buffers.
    """
    # Validate arguments before any nn.Module state exists. The chained
    # comparison also rejects NaN, which ``decay < 0 or decay > 1`` let through.
    if not 0.0 <= decay <= 1.0:
        raise ValueError("Decay must be between 0 and 1")
    super().__init__()

    # Model parameter name -> buffer name on this module (dots removed).
    self.m_name2s_name = {}
    self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
    self.register_buffer(
        "num_updates",
        torch.tensor(0 if use_num_upates else -1, dtype=torch.int),
    )

    # Buffer names already taken -> the model parameter that owns them
    # (None for the bookkeeping buffers registered just above).
    owner = {buf_name: None for buf_name, _ in self.named_buffers()}
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        # Buffer names may not contain '.', so strip it.
        s_name = name.replace(".", "")
        if s_name in owner:
            # register_buffer would silently overwrite the existing buffer,
            # corrupting either another shadow parameter or the EMA state.
            clash = owner[s_name]
            holder = f"buffer {s_name!r}" if clash is None else f"parameter {clash!r}"
            raise ValueError(
                f"Parameter {name!r} maps to EMA buffer name {s_name!r}, "
                f"which is already used by {holder}"
            )
        owner[s_name] = name
        self.m_name2s_name[name] = s_name
        # Detach first so the copy is never recorded in the autograd graph.
        self.register_buffer(s_name, p.detach().clone())

    self.collected_params = []
| 26 | |
| 27 | def reset_num_updates(self): |
| 28 | del self.num_updates |