(self, zero_grad=False)
| 28 | |
@torch.no_grad()
def second_step(self, zero_grad=False):
    """Undo the weight perturbation, then run the wrapped optimizer's update.

    Every parameter that currently has a gradient gets its data set back to
    the tensor saved under ``self.state[param]["old_p"]``, i.e. it moves from
    the perturbed point "w + e(w)" back to "w". That entry is presumably
    written by a preceding first step (not visible here). The gradients
    themselves are left untouched. ``self.base_optimizer.step()`` then
    applies the actual "sharpness-aware" update using those gradients.

    Args:
        zero_grad: when truthy, call ``self.zero_grad()`` after the update.
    """
    # Lazily walk every parameter of every group, keeping only those
    # that received a gradient; order matches the nested group/param loops.
    params_with_grad = (
        param
        for param_group in self.param_groups
        for param in param_group["params"]
        if param.grad is not None
    )
    for param in params_with_grad:
        # NOTE: assigning to ``.data`` rebinds (does not copy), so afterwards
        # the parameter shares storage with the saved "old_p" tensor.
        param.data = self.state[param]["old_p"]

    # Descent step with the base optimizer.
    self.base_optimizer.step()

    if zero_grad:
        self.zero_grad()
| 39 | |
| 40 | @torch.no_grad() |
| 41 | def step(self, closure=None): |