Args: model (nn.Module): model to track with an EMA. decay (float): maximum EMA decay rate. updates (int): initial value of the EMA update counter.
(self, model, decay=0.9999, updates=0)
| 39 | """ |
| 40 | |
| 41 | def __init__(self, model, decay=0.9999, updates=0): |
| 42 | """Create a frozen, eval-mode copy of ``model`` whose weights will hold the EMA. |
| 43 | Args: |
| 44 | model (nn.Module): model to track; a parallel wrapper is unwrapped (``model.module``) before copying. |
| 45 | decay (float): maximum EMA decay rate; the effective decay ramps up toward this value. |
| 46 | updates (int): initial value of the EMA update counter. |
| 47 | """ |
| 48 | # Create EMA copy in eval mode (dtype follows the source model via deepcopy; original note said FP32 — confirm) |
| 49 | self.ema = deepcopy(model.module if is_parallel(model) else model).eval() |
| 50 | self.updates = updates |
| 51 | # decay ramp: decay * (1 - exp(-x / 2000)) is 0 at x=0 and approaches `decay` as x grows, so early EMA weights follow the model closely |
| 52 | self.decay = lambda x: decay * (1 - math.exp(-x / 2000)) |
| 53 | for p in self.ema.parameters():  # EMA weights are never trained directly |
| 54 | p.requires_grad_(False) |
| 55 | |
| 56 | def update(self, model): |
| 57 | # Update EMA parameters |
nothing calls this directly
no test coverage detected