(self, opt: th.optim.Optimizer)
| 187 | return self._optimize_normal(opt) |
| 188 | |
def _optimize_fp16(self, opt: th.optim.Optimizer):
    """Run one optimizer step on the master parameters, guarding on overflow.

    Sequence of operations:

    1. Log the current ``lg_loss_scale``.
    2. Transfer model gradients to the master parameters via
       ``model_grads_to_master_grads``.
    3. Compute gradient/parameter norms with ``grad_scale`` set to
       ``2 ** lg_loss_scale``.
    4. If ``check_overflow`` flags the gradient norm, decrement
       ``lg_loss_scale`` by one, zero the master gradients and bail out
       without touching the optimizer.
    5. Otherwise log both norms, multiply every master gradient in place by
       ``1 / 2 ** lg_loss_scale``, step the optimizer, zero the master
       gradients, copy master parameters back to the model, and grow
       ``lg_loss_scale`` by ``fp16_scale_growth``.

    :param opt: the optimizer whose ``step()`` is invoked on success.
    :return: ``False`` if the step was skipped because of an overflow,
             ``True`` if the optimizer step was applied.
    """
    logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
    model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)

    norm_of_grads, norm_of_params = self._compute_norms(
        grad_scale=2 ** self.lg_loss_scale
    )

    overflowed = check_overflow(norm_of_grads)
    if overflowed:
        # Skip this update entirely: lower the scale exponent and discard
        # the gradients that triggered the overflow check.
        self.lg_loss_scale -= 1
        logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
        zero_master_grads(self.master_params)
        return False

    for metric_name, metric_value in (
        ("grad_norm", norm_of_grads),
        ("param_norm", norm_of_params),
    ):
        logger.logkv_mean(metric_name, metric_value)

    # NOTE(review): presumably this undoes the loss scaling applied before
    # backward() -- confirm against the caller. The factor does not change
    # inside the loop, so it is bound once.
    inverse_scale = 1.0 / (2 ** self.lg_loss_scale)
    for master in self.master_params:
        master.grad.mul_(inverse_scale)

    opt.step()
    zero_master_grads(self.master_params)
    master_params_to_model_params(self.param_groups_and_shapes, self.master_params)

    # A successful step grows the scale exponent.
    self.lg_loss_scale += self.fp16_scale_growth
    return True
| 209 | |
| 210 | def _optimize_normal(self, opt: th.optim.Optimizer): |
| 211 | grad_norm, param_norm = self._compute_norms() |
no test coverage detected