hub / github.com/davda54/sam / first_step

Method first_step

sam.py:16–27
(self, zero_grad=False)

Source from the content-addressed store, hash-verified

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()
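The arithmetic in first_step can be followed without PyTorch. The block below is a minimal pure-Python sketch for a single parameter group, not the repository's code. The names `l2_norm` and `first_step_sketch` are hypothetical. `l2_norm` is only a stand-in for `_grad_norm`, whose body is not shown on this page, so the adaptive branch in particular may not match what the real method computes.

```python
import math


def l2_norm(grads):
    # Hypothetical stand-in for _grad_norm (its body is not shown here):
    # a plain Euclidean norm over all gradient entries.
    return math.sqrt(sum(g * g for g in grads))


def first_step_sketch(weights, grads, rho, adaptive=False):
    """Return (old_weights, perturbed_weights) for one parameter group."""
    grad_norm = l2_norm(grads)
    scale = rho / (grad_norm + 1e-12)  # same epsilon guard as first_step
    old = list(weights)  # plays the role of state[p]["old_p"]
    perturbed = []
    for w, g in zip(weights, grads):
        factor = w * w if adaptive else 1.0  # mirrors torch.pow(p, 2)
        perturbed.append(w + factor * g * scale)  # w + e(w)
    return old, perturbed


old, new = first_step_sketch([1.0, -2.0], [3.0, 4.0], rho=0.5)
step_length = math.sqrt(sum((n - o) ** 2 for n, o in zip(new, old)))
print(new, step_length)
```

In the non-adaptive case the perturbation is the gradient rescaled to length rho, so `step_length` comes out at approximately 0.5 here.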

Callers 2

step · Method · 0.95
train.py · File · 0.80
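How a caller might drive first_step can be sketched on a toy one-dimensional loss f(w) = w**2. This is an illustration under stated assumptions, not the code in step or train.py. The helper names `grad` and `sam_update` are hypothetical. The second half of the update (restore the saved weight, then take a gradient step using the gradient from the perturbed point) is an assumption about what second_step does, since only its signature appears on this page.

```python
def grad(w):
    # Gradient of the toy loss f(w) = w**2.
    return 2.0 * w


def sam_update(w, rho, lr):
    g = grad(w)  # first backward pass
    old_w = w  # saved, like state[p]["old_p"]
    w = w + rho * g / (abs(g) + 1e-12)  # first_step: climb to w + e(w)
    g_adv = grad(w)  # second backward pass, at the perturbed point
    # Assumed second_step: restore old_w, then descend using g_adv.
    return old_w - lr * g_adv


w = 1.0
for _ in range(20):
    w = sam_update(w, rho=0.05, lr=0.1)
print(w)
```

The caller needs two gradient evaluations per update, one before and one after the perturbation, which is why the perturbed weights must be undone before the real step.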

Calls 1

_grad_norm · Method · 0.95

Tested by

no test coverage detected