(self, iter=None)
| 25 | b_ema.copy_(b) |
| 26 | |
def update(self, iter=None):
    """Blend the source module's parameters and buffers into the target module.

    Every target parameter/buffer ``t`` is overwritten with
    ``s.lerp(t, decay)``, i.e. ``decay * t + (1 - decay) * s``, where ``s`` is
    the matching tensor of ``self.source`` (modules are paired by iteration
    order).

    Args:
        iter: Optional current training iteration. When it is given and
            ``0 <= iter < self.start_iter``, ``decay`` is forced to ``0.0`` so
            the target becomes an exact copy of the source. When it is
            ``None`` (the default) or outside that range, ``self.decay`` is
            used. (The name shadows the builtin ``iter``; it is kept for
            keyword-argument compatibility with existing callers.)
    """
    # `iter` defaults to None; `None >= 0` raises TypeError in Python 3, so the
    # warm-up rule is only applied when an iteration number is actually given.
    if iter is not None and 0 <= iter < self.start_iter:
        decay = 0.0
    else:
        decay = self.decay

    with torch.no_grad():
        for p_ema, p in zip(self.target.parameters(), self.source.parameters()):
            p_ema.copy_(p.lerp(p_ema, decay))
        for (b_ema_name, b_ema), (_, b) in zip(
            self.target.named_buffers(), self.source.named_buffers()
        ):
            # Counters such as `num_batches_tracked` are copied verbatim, and so
            # is any non-float buffer: torch.lerp is not implemented for
            # integer/bool dtypes and would raise.
            blendable = b.is_floating_point() or b.is_complex()
            if "num_batches_tracked" in b_ema_name or not blendable:
                b_ema.copy_(b)
            else:
                b_ema.copy_(b.lerp(b_ema, decay))
| 41 | |
| 42 | |
| 43 | class EmaStylegan2(object): |
no outgoing calls
no test coverage detected