(self, module)
| 14 | self.shadow[name] = param.data.clone() |
| 15 | |
def update(self, module):
    """Blend the current trainable parameters of ``module`` into ``self.shadow``.

    Each trainable parameter's shadow tensor is replaced by
    ``(1 - self.mu) * param + self.mu * shadow``. ``self.mu`` is therefore the
    weight kept on the previous shadow value.

    Args:
        module: the model whose parameters are tracked. An ``nn.DataParallel``
            wrapper is unwrapped first. Presumably this keeps the parameter
            names free of the wrapper's ``module.`` prefix so they match the
            keys in ``self.shadow``; verify against the registration code.

    Raises:
        KeyError: if a trainable parameter name has no entry in ``self.shadow``.
    """
    # Look through a DataParallel wrapper to the underlying model.
    target = module.module if isinstance(module, nn.DataParallel) else module
    decay = self.mu
    for pname, p in target.named_parameters():
        # Frozen parameters are not tracked.
        if not p.requires_grad:
            continue
        previous = self.shadow[pname].data
        self.shadow[pname].data = (1. - decay) * p.data + decay * previous
| 23 | |
| 24 | def ema(self, module): |
| 25 | if isinstance(module, nn.DataParallel): |
no outgoing calls
no test coverage detected