(self)
| 226 | loss.backward() |
| 227 | |
def optimize_fp16(self):
    """Apply one optimizer update on the fp16 training path.

    When any gradient in ``self.model_params`` holds a non-finite value the
    update is skipped: ``self.lg_loss_scale`` is decremented by one, a
    message is logged, and the method returns without stepping.

    Otherwise the method, in order: copies model gradients to the master
    parameters, rescales the master gradient by ``1 / 2**lg_loss_scale``,
    logs the gradient norm, anneals the learning rate, steps the optimizer,
    updates every EMA parameter set, copies the master parameters back to
    the model, and raises ``self.lg_loss_scale`` by
    ``self.fp16_scale_growth``.
    """
    grads_are_finite = all(
        th.isfinite(param.grad).all() for param in self.model_params
    )
    if not grads_are_finite:
        # Skip this update entirely and retry later with a smaller scale.
        self.lg_loss_scale -= 1
        logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
        return

    model_grads_to_master_grads(self.model_params, self.master_params)

    # Presumably this undoes the 2**lg_loss_scale scaling applied to the loss
    # before backward() -- confirm against the backward pass.
    # NOTE(review): only master_params[0] is rescaled; this assumes the master
    # parameters are held as a single tensor -- TODO confirm.
    unscale_factor = 1.0 / (2 ** self.lg_loss_scale)
    self.master_params[0].grad.mul_(unscale_factor)

    self._log_grad_norm()
    self._anneal_lr()
    self.opt.step()

    # One EMA parameter set is maintained per configured rate.
    for ema_rate, ema_param_set in zip(self.ema_rate, self.ema_params):
        update_ema(ema_param_set, self.master_params, rate=ema_rate)

    master_params_to_model_params(self.model_params, self.master_params)

    # A successful step lets the (log2) loss scale grow again.
    self.lg_loss_scale += self.fp16_scale_growth
| 243 | |
| 244 | def optimize_normal(self): |
| 245 | self._log_grad_norm() |
no test coverage detected