(self, zero_grad=False)
| 14 | |
@torch.no_grad()
def first_step(self, zero_grad=False):
    """Move every parameter that has a gradient by its perturbation ``e(w)``.

    The global gradient norm comes from ``self._grad_norm()``. For each param
    group, the step size is ``group["rho"] / (norm + 1e-12)``; the tiny epsilon
    avoids dividing by zero when all gradients vanish. When ``group["adaptive"]``
    is truthy, the perturbation is additionally scaled element-wise by the
    squared weights ``p ** 2``. Otherwise the plain gradient direction is used.

    Before a parameter is modified, its current value is copied into
    ``self.state[p]["old_p"]``. That is presumably read back by a later step
    to undo the move (confirm against ``second_step``). Parameters whose
    ``.grad`` is ``None`` are left untouched and get no ``old_p`` entry.

    Args:
        zero_grad: if true, call ``self.zero_grad()`` after all parameters
            have been perturbed.
    """
    norm = self._grad_norm()
    for param_group in self.param_groups:
        rho_over_norm = param_group["rho"] / (norm + 1e-12)

        for param in param_group["params"]:
            grad = param.grad
            if grad is None:
                continue
            # Keep a copy of the unperturbed weights before changing them in place.
            self.state[param]["old_p"] = param.data.clone()
            # The weight factor is evaluated from the *pre-update* value of param.
            weight_factor = torch.pow(param, 2) if param_group["adaptive"] else 1.0
            perturbation = weight_factor * grad * rho_over_norm.to(param)
            param.add_(perturbation)  # climb to the local maximum "w + e(w)"

    if zero_grad:
        self.zero_grad()
| 28 | |
| 29 | @torch.no_grad() |
| 30 | def second_step(self, zero_grad=False): |
no test coverage detected