| 2 | |
| 3 | |
| 4 | class SAM(torch.optim.Optimizer): |
| 5 | def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs): |
| 6 | assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}" |
| 7 | |
| 8 | defaults = dict(rho=rho, adaptive=adaptive, **kwargs) |
| 9 | super(SAM, self).__init__(params, defaults) |
| 10 | |
| 11 | self.base_optimizer = base_optimizer(self.param_groups, **kwargs) |
| 12 | self.param_groups = self.base_optimizer.param_groups |
| 13 | self.defaults.update(self.base_optimizer.defaults) |
| 14 | |
| 15 | @torch.no_grad() |
| 16 | def first_step(self, zero_grad=False): |
| 17 | grad_norm = self._grad_norm() |
| 18 | for group in self.param_groups: |
| 19 | scale = group["rho"] / (grad_norm + 1e-12) |
| 20 | |
| 21 | for p in group["params"]: |
| 22 | if p.grad is None: continue |
| 23 | self.state[p]["old_p"] = p.data.clone() |
| 24 | e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p) |
| 25 | p.add_(e_w) # climb to the local maximum "w + e(w)" |
| 26 | |
| 27 | if zero_grad: self.zero_grad() |
| 28 | |
| 29 | @torch.no_grad() |
| 30 | def second_step(self, zero_grad=False): |
| 31 | for group in self.param_groups: |
| 32 | for p in group["params"]: |
| 33 | if p.grad is None: continue |
| 34 | p.data = self.state[p]["old_p"] # get back to "w" from "w + e(w)" |
| 35 | |
| 36 | self.base_optimizer.step() # do the actual "sharpness-aware" update |
| 37 | |
| 38 | if zero_grad: self.zero_grad() |
| 39 | |
| 40 | @torch.no_grad() |
| 41 | def step(self, closure=None): |
| 42 | assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided" |
| 43 | closure = torch.enable_grad()(closure) # the closure should do a full forward-backward pass |
| 44 | |
| 45 | self.first_step(zero_grad=True) |
| 46 | closure() |
| 47 | self.second_step() |
| 48 | |
| 49 | def _grad_norm(self): |
| 50 | shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism |
| 51 | norm = torch.norm( |
| 52 | torch.stack([ |
| 53 | ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device) |
| 54 | for group in self.param_groups for p in group["params"] |
| 55 | if p.grad is not None |
| 56 | ]), |
| 57 | p=2 |
| 58 | ) |
| 59 | return norm |
| 60 | |
| 61 | def load_state_dict(self, state_dict): |