| 146 | |
| 147 | |
| 148 | class MixedPrecisionTrainer: |
def __init__(
    self,
    *,
    model,
    use_fp16=False,
    fp16_scale_growth=1e-3,
    initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
):
    """Record training state and, when fp16 is enabled, prepare the model.

    All arguments are keyword-only.

    :param model: object exposing ``parameters()``, ``named_parameters()``
                  and (for fp16 use) ``convert_to_fp16()``.
    :param use_fp16: when true, master parameters are built through
                     ``make_master_params`` and the model is converted
                     with ``convert_to_fp16()``.
    :param fp16_scale_growth: stored on the instance as given; it is not
                              consumed anywhere in this method.
    :param initial_lg_loss_scale: starting value of ``lg_loss_scale``, the
                                  base-2 exponent that ``backward`` turns
                                  into the loss multiplier.
    """
    # Snapshot the parameter list first; everything below works from it or
    # from the model's named parameters.
    params = list(model.parameters())

    # Without fp16 the "master" parameters are simply the model's own
    # parameters (the very same list object) and no grouping is needed.
    groups_and_shapes = None
    master = params

    if use_fp16:
        # Order matters: group, build the master parameters, and only then
        # convert the model in place.
        groups_and_shapes = get_param_groups_and_shapes(
            model.named_parameters()
        )
        master = make_master_params(groups_and_shapes)
        model.convert_to_fp16()

    self.model = model
    self.use_fp16 = use_fp16
    self.fp16_scale_growth = fp16_scale_growth
    self.model_params = params
    self.master_params = master
    self.param_groups_and_shapes = groups_and_shapes
    self.lg_loss_scale = initial_lg_loss_scale
| 172 | |
def zero_grad(self):
    """Clear gradients for the model's parameters.

    Hands ``self.model_params`` to the module-level ``zero_grad`` helper,
    which is defined outside this method; the master parameters are not
    touched here.
    """
    model_params = self.model_params
    zero_grad(model_params)
| 175 | |
def backward(self, loss: th.Tensor):
    """Backpropagate ``loss``, scaling it first when fp16 is enabled.

    :param loss: tensor to call ``backward()`` on.

    With ``use_fp16`` set, the loss is multiplied by
    ``2 ** self.lg_loss_scale`` before backpropagation; otherwise it is
    backpropagated unchanged.
    """
    if not self.use_fp16:
        loss.backward()
        return

    # lg_loss_scale is kept as a base-2 exponent, so the actual multiplier
    # is recomputed from it on every call.
    scaled_loss = loss * (2 ** self.lg_loss_scale)
    scaled_loss.backward()
| 182 | |
def optimize(self, opt: th.optim.Optimizer):
    """Apply one optimizer step via the path matching ``use_fp16``.

    :param opt: optimizer forwarded unchanged to the selected helper.
    :return: whatever the selected helper returns
             (``_optimize_fp16`` when ``use_fp16`` is set, otherwise
             ``_optimize_normal``).
    """
    # Only the chosen helper is looked up, exactly like an if/else would.
    step = self._optimize_fp16 if self.use_fp16 else self._optimize_normal
    return step(opt)
| 188 | |
| 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): |
| 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) |
| 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) |
| 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) |
| 193 | if check_overflow(grad_norm): |
| 194 | self.lg_loss_scale -= 1 |
| 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") |
| 196 | zero_master_grads(self.master_params) |
| 197 | return False |
| 198 | |
| 199 | logger.logkv_mean("grad_norm", grad_norm) |
| 200 | logger.logkv_mean("param_norm", param_norm) |
| 201 | |
| 202 | for p in self.master_params: |
| 203 | p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) |
| 204 | opt.step() |
| 205 | zero_master_grads(self.master_params) |