Step the EMA model towards the current model.
Signature: `update_ema(ema_model, model, decay=0.9999)`
| 189 | |
@torch.no_grad()
def update_ema(ema_model, model, decay=0.9999):
    """
    Step the EMA model towards the current model.

    Every parameter of ``ema_model`` is updated in place as
    ``ema = decay * ema + (1 - decay) * param``, where ``param`` is the
    parameter with the same name in ``model``. Only parameters are blended;
    buffers are not touched.

    Args:
        ema_model: module holding the exponential-moving-average weights;
            its parameters are modified in place.
        model: module whose current parameters are blended in. Every
            parameter name of ``model`` must also exist in ``ema_model``.
        decay: fraction of the previous EMA value that is kept, in ``[0, 1]``.
            ``decay=0`` copies ``model``'s parameters into ``ema_model``;
            ``decay=1`` leaves ``ema_model`` unchanged.

    Raises:
        ValueError: if ``decay`` is outside ``[0, 1]`` (or is NaN).
        KeyError: if ``model`` has a parameter that ``ema_model`` lacks.
    """
    # A decay outside [0, 1] does not compute an average at all; fail loudly
    # instead of silently drifting the EMA weights.
    if not 0.0 <= decay <= 1.0:
        raise ValueError(f"decay must be in [0, 1], got {decay}")

    ema_params = OrderedDict(ema_model.named_parameters())
    # Iterate the source parameters directly; no need to copy them into a dict.
    for name, param in model.named_parameters():
        # TODO: Consider applying only to params that require_grad to avoid
        # small numerical changes of pos_embed
        try:
            ema_param = ema_params[name]
        except KeyError:
            raise KeyError(
                f"parameter {name!r} of model not found in ema_model"
            ) from None
        ema_param.mul_(decay).add_(param.data, alpha=1 - decay)
| 201 | |
| 202 | def requires_grad(model, flag=True): |
| 203 | """ |
no outgoing calls
no test coverage detected