(self)
| 40 | self.state[key] = val.detach().to(device, copy=True) |
| 41 | |
def update(self):
    """Blend the model's current float32 state into the averaged state.

    The blending weight depends on ``self.unbias``:

    * when set, ``self.count`` is advanced to ``count * decay + 1`` and
      the weight is ``1 / count``;
    * otherwise the weight is the constant ``1 - decay``.

    Every float32 entry of ``self.model.state_dict()`` is then merged
    in place into ``self.state`` as ``old * (1 - weight) + new * weight``.
    Entries of any other dtype are left untouched.
    """
    if not self.unbias:
        weight = 1 - self.decay
    else:
        self.count = self.count * self.decay + 1
        weight = 1 / self.count
    keep = 1 - weight

    for name, tensor in self.model.state_dict().items():
        if tensor.dtype == torch.float32:
            # Fall back to the tensor's own device when no device is set.
            target = self.device or tensor.device
            averaged = self.state[name]
            averaged.mul_(keep)
            averaged.add_(tensor.detach().to(target), alpha=weight)
| 54 | |
| 55 | @contextmanager |
| 56 | def swap(self): |
no test coverage detected