Update the exponential moving average (EMA) of model weights.
(cfg, num_gpus, model, model_ema, cur_epoch, cur_iter)
| 232 | return rt |
| 233 | |
def update_model_ema(cfg, num_gpus, model, model_ema, cur_epoch, cur_iter):
    """Update exponential moving average (EMA) of model weights in place.

    Every ``cfg.TRAIN.EMA_UPDATE_PERIOD`` iterations, each entry of
    ``model_ema``'s state dict is moved toward the matching entry of
    ``model``'s state dict::

        ema = ema * (1 - alpha) + current * alpha

    Floating-point (and complex) entries are blended with the formula above.
    Non-floating entries (integer/bool buffers) cannot be averaged without
    truncation, so they are copied over verbatim.

    Args:
        cfg: Config node. Reads ``TRAIN.EMA_UPDATE_PERIOD``,
            ``TRAIN.EMA_ALPHA``, ``TRAIN.EPOCHS``, ``TRAIN.WARMUP_EPOCHS`` and
            ``DATA.BATCH_SIZE``.
        num_gpus: Multiplied by ``DATA.BATCH_SIZE`` to obtain the total batch
            size used when scaling alpha.
        model: The model being trained; source of the current weights.
        model_ema: The EMA copy; its state-dict tensors are modified in place.
        cur_epoch: Current epoch. While it is below ``TRAIN.WARMUP_EPOCHS``
            the weights are copied over exactly (alpha = 1).
        cur_iter: Current iteration. The update only runs when it is a
            multiple of the update period.

    Returns:
        None. A falsy update period (``None`` or ``0``) disables EMA.
    """
    update_period = cfg.TRAIN.EMA_UPDATE_PERIOD
    # None / 0 disables EMA; otherwise only update every `update_period` iters.
    if not update_period or cur_iter % update_period != 0:
        return
    # Adjust alpha to be fairly independent of other parameters: a larger
    # total batch, fewer epochs, or sparser updates each scale the step up.
    total_batch_size = num_gpus * cfg.DATA.BATCH_SIZE
    adjust = total_batch_size / cfg.TRAIN.EPOCHS * update_period
    alpha = min(1.0, cfg.TRAIN.EMA_ALPHA * adjust)
    # During warmup simply copy over weights instead of using ema.
    if cur_epoch < cfg.TRAIN.WARMUP_EPOCHS:
        alpha = 1.0
    # Take ema of all state-dict entries (not just named parameters), so
    # buffers are tracked too.
    params = unwrap_model(model).state_dict()
    for name, ema_param in unwrap_model(model_ema).state_dict().items():
        src = params[name]
        is_blendable = ema_param.is_floating_point() or ema_param.is_complex()
        if alpha == 1.0 or not is_blendable:
            # Exact copy: during warmup this also recovers from any inf/NaN
            # left in the EMA (0 * inf would otherwise yield NaN), and for
            # integer/bool buffers it avoids float blending followed by
            # truncation on copy-back.
            ema_param.copy_(src)
        else:
            # In-place blend: ema * (1 - alpha) + src * alpha, without
            # allocating temporaries.
            ema_param.mul_(1.0 - alpha).add_(src, alpha=alpha)
no test coverage detected