One update of the EMA model based on new model weights
(self, new_model)
| 91 | return self.decay |
| 92 | |
def _step_internal(self, new_model):
    """One update of the EMA model based on new model weights.

    For every tensor in ``new_model.state_dict()`` the matching EMA tensor
    (looked up in ``self.fp32_params`` when ``self.ema_fp32`` is set,
    otherwise in ``self.model.state_dict()``) is updated in place as
    ``ema = decay * ema + (1 - decay) * param`` with ``decay = self.decay``.

    Tensors whose key is in ``self.skip_keys``, and tensors whose EMA copy
    is neither floating point nor complex (e.g. an int64
    ``num_batches_tracked`` buffer), are copied verbatim instead of being
    averaged. Keys containing ``"version"`` are neither averaged nor passed
    on. All updated tensors are finally handed to
    ``self.restore(ema_state_dict, build_fp32_params=False)``.

    Args:
        new_model: object exposing ``state_dict()`` (presumably a
            ``torch.nn.Module`` -- confirm against callers) whose current
            weights are folded into the EMA.

    Raises:
        ValueError: if a tensor's shape differs from its EMA counterpart.
    """
    decay = self.decay

    ema_state_dict = {}
    ema_params = self.fp32_params if self.ema_fp32 else self.model.state_dict()
    for key, param in new_model.state_dict().items():
        if isinstance(param, dict):
            continue
        try:
            ema_param = ema_params[key]
        except KeyError:
            # First time this key is seen: seed it from the current weights
            # and register it, so the copy branch below can find it and later
            # steps keep averaging instead of re-seeding every time.
            ema_param = (
                param.float().clone() if param.ndim == 1 else copy.deepcopy(param)
            )
            ema_params[key] = ema_param

        if param.shape != ema_param.shape:
            raise ValueError(
                "incompatible tensor shapes between model param and ema param: "
                + "{} vs. {}".format(param.shape, ema_param.shape)
            )

        if "version" in key:
            # Do not decay a model.version pytorch param
            continue

        # An in-place mul_ by a float decay raises on integer/bool tensors,
        # so those (including an int64 num_batches_tracked) are copied.
        decayable = ema_param.is_floating_point() or ema_param.is_complex()
        if key in self.skip_keys or not decayable:
            ema_param = param.to(dtype=ema_param.dtype).clone()
            ema_params[key].copy_(ema_param)
        else:
            ema_param.mul_(decay)
            ema_param.add_(param.to(dtype=ema_param.dtype), alpha=1 - decay)
        ema_state_dict[key] = ema_param
    self.restore(ema_state_dict, build_fp32_params=False)
| 127 | |
| 128 | def step(self, new_model): |
| 129 | """Step. |
no test coverage detected