(self, model)
| 35 | p.requires_grad_(False) |
| 36 | |
def update(self, model):
    """Blend the live model's weights into the EMA copy.

    Increments ``self.updates`` and evaluates ``self.decay`` at the new
    count to get the blend factor ``d``. Then, for every floating-point
    entry of ``self.model.state_dict()``, applies
    ``ema = d * ema + (1 - d) * live``. The update is done in place on the
    tensors that ``state_dict()`` returns. Non-floating-point entries
    (integer buffers, for example) are left unchanged.

    All of this happens under ``torch.no_grad()``, so no autograd history
    is recorded.

    Args:
        model: the model being trained. When ``dist_utils.is_parallel``
            reports it as wrapped, the weights are read from its
            ``.module`` instead.
    """
    with torch.no_grad():
        self.updates += 1
        decay = self.decay(self.updates)

        # Unwrap parallel containers so the keys match the EMA model's.
        source = model.module if dist_utils.is_parallel(model) else model
        live_state = source.state_dict()

        for name, ema_value in self.model.state_dict().items():
            if not ema_value.dtype.is_floating_point:
                continue
            # In-place ops: by PyTorch's default, state_dict() tensors
            # share storage with the EMA model's parameters/buffers.
            ema_value *= decay
            ema_value += (1.0 - decay) * live_state[name].detach()
| 50 | def update_attr(self, |
| 51 | model, |
no outgoing calls