(self)
| 47 | self.second_step() |
| 48 | |
def _grad_norm(self):
    """Return the global L2 norm of all parameter gradients.

    For every parameter that has a gradient, the gradient is multiplied
    elementwise by ``|p|`` when its group's ``"adaptive"`` entry is truthy
    (otherwise it is used as-is), and its L2 norm is taken. The result is
    the L2 norm of those per-parameter norms.

    Returns:
        A 0-dim tensor located on the device of the first parameter of the
        first parameter group. If no parameter has a gradient, a zero
        scalar on that device is returned.
    """
    # Put everything on the same device, in case of model parallelism.
    shared_device = self.param_groups[0]["params"][0].device

    per_param_norms = [
        ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
        for group in self.param_groups
        for p in group["params"]
        if p.grad is not None
    ]

    if not per_param_norms:
        # torch.stack raises on an empty list; with no gradients at all the
        # norm is simply zero.
        return torch.zeros((), device=shared_device)

    return torch.norm(torch.stack(per_param_norms), p=2)
| 60 | |
| 61 | def load_state_dict(self, state_dict): |
| 62 | super().load_state_dict(state_dict) |