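@param model model to initialize the EMA with (a deep copy of it is kept). @param ema_decay value stored as the EMA's decay (default 0.9999). @param ema_fp32 if true, build_fp32_params() is called during construction (default False). @param device If provided, copy EMA to this device (e.g. gpu). Otherwise EMA is in the same device as the model. @param skip_keys optional collection stored as skip_keys; an empty set is used when it is not given. Note: the signature has no `config` (EMAConfig) parameter; ema_decay and ema_fp32 are passed directly, and ema_update_freq is not an argument of this constructor.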
(self, model, ema_decay=0.9999, ema_fp32=False, device=None, skip_keys=None)
| 18 | """Exponential Moving Average of Fairseq Models""" |
| 19 | |
def __init__(self, model, ema_decay=0.9999, ema_fp32=False, device=None, skip_keys=None):
    """
    Set up the EMA state from a deep copy of ``model``.

    @param model model to initialize the EMA with; it is deep-copied,
        so the caller's instance is not modified here
    @param ema_decay value stored as ``self.decay`` (presumably the EMA
        decay factor -- confirm against the update code)
    @param ema_fp32 when truthy, ``build_fp32_params()`` is invoked
        during setup
    @param device If provided, copy EMA to this device (e.g. gpu).
        Otherwise EMA is in the same device as the model.
    @param skip_keys stored as ``self.skip_keys``; a falsy value is
        replaced by an empty set
    """
    self.decay = ema_decay
    self.ema_fp32 = ema_fp32

    # Work on a private copy of the model with gradients switched off.
    ema_model = copy.deepcopy(model)
    ema_model.requires_grad_(False)
    self.model = ema_model

    self.skip_keys = skip_keys if skip_keys else set()
    self.fp32_params = {}

    if device is not None:
        logging.info(f"Copying EMA model to device {device}")
        self.model = ema_model.to(device=device)

    # Must run after the model is in place (and on its final device).
    if self.ema_fp32:
        self.build_fp32_params()

    self.update_freq_counter = 0
| 44 | |
| 45 | def build_fp32_params(self, state_dict=None): |
| 46 | """ |
nothing calls this directly
no test coverage detected