Replace the batch normalization statistics of the teacher model with those of the student model.
(model, ema_model)
| 79 | |
| 80 | |
def update_bn(model, ema_model):
    """
    Replace the batch normalization statistics of the teacher model with those of the student model.

    Modules of ``ema_model`` (teacher) and ``model`` (student) are paired
    positionally via ``modules()``, so both networks are expected to share the
    same architecture. For every pair in which both modules are batch-norm
    layers (instances of ``torch.nn.modules.batchnorm._BatchNorm``), the
    student's ``running_mean``, ``running_var`` and ``num_batches_tracked``
    buffers are copied in place into the teacher's buffers. Learnable affine
    parameters (``weight``/``bias``) are not touched. Buffers that are ``None``
    (layers built with ``track_running_stats=False``) are skipped.

    Args:
        model: student network; source of the running statistics.
        ema_model: teacher (EMA) network; its running statistics are overwritten.
    """
    # Local imports keep this block self-contained regardless of the
    # file-level imports.
    import torch
    from torch.nn.modules.batchnorm import _BatchNorm

    with torch.no_grad():
        for ema_module, module in zip(ema_model.modules(), model.modules()):
            # Select layers by type, not by the substring 'bn' in their name.
            # A name check also matches containers/non-BN children such as
            # 'conv_bn' or 'conv_bn.0' (whose state_dict has no 'running_mean',
            # raising KeyError), and it misses BN layers named otherwise.
            if not (isinstance(ema_module, _BatchNorm) and isinstance(module, _BatchNorm)):
                continue
            for buffer_name in ('running_mean', 'running_var', 'num_batches_tracked'):
                src = getattr(module, buffer_name)
                dst = getattr(ema_module, buffer_name)
                # These buffers are registered as None when
                # track_running_stats=False; there is nothing to copy then.
                if src is None or dst is None:
                    continue
                # In-place copy keeps the teacher's own tensor (no aliasing
                # with the student's buffer).
                dst.copy_(src)