MCPcopy Index your code
hub / github.com/davda54/sam / SAM

Class SAM

sam.py:4–63  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

2
3
4class SAM(torch.optim.Optimizer):
5 def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
6 assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"
7
8 defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
9 super(SAM, self).__init__(params, defaults)
10
11 self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
12 self.param_groups = self.base_optimizer.param_groups
13 self.defaults.update(self.base_optimizer.defaults)
14
15 @torch.no_grad()
def first_step(self, zero_grad=False):
    """Move every parameter that has a gradient to the perturbed point ``w + e(w)``.

    The perturbation is ``grad * rho / (grad_norm + 1e-12)``, additionally
    multiplied element-wise by ``p ** 2`` when the group's ``"adaptive"``
    entry is truthy. The pre-perturbation values are saved in
    ``self.state[p]["old_p"]``.

    Args:
        zero_grad: When truthy, call ``self.zero_grad()`` after perturbing.
    """
    norm = self._grad_norm()
    for group in self.param_groups:
        # The small constant keeps the division finite when the norm is zero.
        step_scale = group["rho"] / (norm + 1e-12)

        for param in group["params"]:
            if param.grad is not None:
                # Snapshot the current weights before perturbing them.
                self.state[param]["old_p"] = param.data.clone()
                if group["adaptive"]:
                    weight = torch.pow(param, 2)
                else:
                    weight = 1.0
                # ``.to(param)`` matches the scale to the parameter's device and dtype.
                perturbation = weight * param.grad * step_scale.to(param)
                param.add_(perturbation)  # climb to the local maximum "w + e(w)"

    if zero_grad:
        self.zero_grad()
28
29 @torch.no_grad()
def second_step(self, zero_grad=False):
    """Restore the weights saved by ``first_step`` and run the base optimizer.

    Every parameter that has an ``"old_p"`` snapshot in ``self.state`` is
    reset to it, then ``self.base_optimizer.step()`` is called once.

    The restore is keyed on the snapshot, not on ``p.grad``. A parameter
    perturbed by ``first_step`` can have ``p.grad is None`` here (for example
    when gradients were cleared to ``None`` and the second backward pass did
    not reach it); gating on the gradient would leave it stuck at
    ``w + e(w)``. A parameter that only gained a gradient in the second pass
    has no snapshot and is left where it is instead of raising ``KeyError``.

    Args:
        zero_grad: When truthy, call ``self.zero_grad()`` after the update.
    """
    for group in self.param_groups:
        for p in group["params"]:
            # ``.get`` avoids creating an empty state entry for parameters
            # that were never perturbed.
            saved = self.state.get(p)
            if saved is None or "old_p" not in saved:
                continue
            # Pop the snapshot so a stale copy can never be restored on a
            # later step in which this parameter is not perturbed.
            p.data = saved.pop("old_p")  # get back to "w" from "w + e(w)"

    self.base_optimizer.step()  # do the actual "sharpness-aware" update

    if zero_grad:
        self.zero_grad()
39
40 @torch.no_grad()
def step(self, closure=None):
    """Run both SAM stages in one call.

    ``first_step`` runs before the closure, so gradients at the current
    weights must already be populated when this method is called. The
    closure is then invoked with gradient tracking re-enabled, followed by
    ``second_step``.

    Args:
        closure: Callable that performs a full forward-backward pass.

    Raises:
        AssertionError: If ``closure`` is ``None``.
    """
    assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
    # Re-enable autograd for the closure regardless of the caller's grad mode.
    forward_backward = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

    self.first_step(zero_grad=True)
    forward_backward()
    self.second_step()
48
49 def _grad_norm(self):
50 shared_device = self.param_groups[0]["params"][0].device # put everything on the same device, in case of model parallelism
51 norm = torch.norm(
52 torch.stack([
53 ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
54 for group in self.param_groups for p in group["params"]
55 if p.grad is not None
56 ]),
57 p=2
58 )
59 return norm
60
61 def load_state_dict(self, state_dict):

Callers 1

train.py · File · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected