(self, model)
| 54 | p.requires_grad_(False) |
| 55 | |
| 56 | def update(self, model): |
| 57 | # Update EMA parameters |
| 58 | with torch.no_grad(): |
| 59 | self.updates += 1 |
| 60 | d = self.decay(self.updates) |
| 61 | |
| 62 | msd = ( |
| 63 | model.module.state_dict() if is_parallel(model) else model.state_dict() |
| 64 | ) # model state_dict |
| 65 | for k, v in self.ema.state_dict().items(): |
| 66 | if v.dtype.is_floating_point: |
| 67 | v *= d |
| 68 | v += (1.0 - d) * msd[k].detach() |
| 69 | |
| 70 | def update_attr(self, model, include=(), exclude=("process_group", "reducer")): |
| 71 | # Update EMA attributes |
no test coverage detected