hub / github.com/davda54/sam / _grad_norm

Method _grad_norm

sam.py:49–59
(self)

Source from the content-addressed store, hash-verified

        self.second_step()

    def _grad_norm(self):
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
            torch.stack([
                ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                for group in self.param_groups for p in group["params"]
                if p.grad is not None
            ]),
            p=2
        )
        return norm

    def load_state_dict(self, state_dict):
        super().load_state_dict(state_dict)
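
The method above takes an L2 norm per parameter tensor, then an L2 norm over those per-tensor norms, which equals the L2 norm of all gradient entries flattened together. The sketch below illustrates that arithmetic in plain Python, with no torch dependency. It is not the library's code: the names `l2` and `grad_norm` and the list-of-lists representation of parameters and gradients are assumptions made for the example.

```python
import math


def l2(values):
    """Euclidean norm of a flat list of floats."""
    return math.sqrt(sum(v * v for v in values))


def grad_norm(params, grads, adaptive=False):
    """Two-level L2 norm over gradients (hypothetical helper).

    params and grads are parallel lists of flat float lists, one per tensor.
    A gradient of None marks a tensor without a gradient and is skipped.
    With adaptive=True each gradient entry is scaled by |parameter| first.
    """
    per_tensor = [
        l2([(abs(p) if adaptive else 1.0) * g for p, g in zip(ps, gs)])
        for ps, gs in zip(params, grads)
        if gs is not None
    ]
    return l2(per_tensor)


# Three "tensors"; the last one has no gradient and is ignored.
params = [[1.0, -2.0], [0.5], [3.0]]
grads = [[3.0, 0.0], [4.0], None]

plain = grad_norm(params, grads)                   # sqrt(3^2 + 4^2)
scaled = grad_norm(params, grads, adaptive=True)   # sqrt(3^2 + 2^2)

# The nested norm matches the norm of every gradient entry flattened.
flat = l2([g for gs in grads if gs is not None for g in gs])
```

Here `plain` and `flat` agree, while `scaled` differs because each gradient entry is weighted by the magnitude of its parameter before the norm is taken.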

Callers 1

first_step · Method · 0.95

Calls

no outgoing calls

Tested by

no test coverage detected