MCPcopy
hub / github.com/alibaba/EasyCV / step

Method step

easycv/core/optimizer/lars.py:72–124  ·  view source on GitHub ↗

Performs a single optimization step. Arguments: closure (callable, optional): A closure that reevaluates the model and returns the loss.

(self, closure=None)

Source from the content-addressed store, hash-verified

70
@torch.no_grad()
def step(self, closure=None):
    """Performs a single optimization step.

    For every parameter that has a gradient, the update is
    ``(grad + weight_decay * p) * local_lr * lr``, optionally routed
    through a momentum buffer, and is then subtracted from the parameter
    in place.

    ``local_lr`` is the layer-wise ratio
    ``eta * ||p|| / (||grad|| + weight_decay * ||p||)``. It is ``1`` for
    groups whose ``lars_exclude`` entry is truthy, and also when the ratio
    is degenerate (``||p||`` is zero or the denominator is zero), so that
    such parameters receive a plain, unscaled update instead of raising
    ``ZeroDivisionError`` or being frozen by a zero step size.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        The value returned by ``closure``, or ``None`` when no closure is
        given.
    """
    loss = None
    if closure is not None:
        # The method itself runs under no_grad; re-enable autograd so the
        # closure can compute gradients.
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        eta = group['eta']
        nesterov = group['nesterov']
        lr = group['lr']
        lars_exclude = group.get('lars_exclude', False)

        for p in group['params']:
            if p.grad is None:
                continue

            d_p = p.grad

            # Default to no layer-wise scaling; only override it when the
            # ratio below is well defined and non-zero.
            local_lr = 1.
            if not lars_exclude:
                weight_norm = torch.norm(p).item()
                grad_norm = torch.norm(d_p).item()
                denom = grad_norm + weight_decay * weight_norm
                # Both norms are Python floats here, so a zero denominator
                # would raise ZeroDivisionError, and a zero weight norm
                # would yield local_lr == 0 and pin the parameter forever.
                if weight_norm != 0 and denom != 0:
                    # Compute local learning rate for this layer
                    local_lr = eta * weight_norm / denom

            actual_lr = local_lr * lr
            # Out-of-place ops: p.grad itself is left untouched.
            d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr)
            if momentum != 0:
                param_state = self.state[p]
                if 'momentum_buffer' not in param_state:
                    # First step for this parameter: seed the buffer with
                    # the current (already scaled) update.
                    buf = param_state['momentum_buffer'] = \
                        torch.clone(d_p).detach()
                else:
                    buf = param_state['momentum_buffer']
                    buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                if nesterov:
                    d_p = d_p.add(buf, alpha=momentum)
                else:
                    d_p = buf
            p.add_(-d_p)

    return loss

Callers 14

after_train_iterMethod · 0.45
after_train_iterMethod · 0.45
mp_vs_ddp_mainFunction · 0.45
mp_mainFunction · 0.45
baseline_mainFunction · 0.45
test_no_parallelMethod · 0.45
test_mpMethod · 0.45
_run_baseline_train_mainFunction · 0.45
_run_mp_train_mainFunction · 0.45
_run_baseline_trainMethod · 0.45
_run_mp_1gpu_trainMethod · 0.45

Calls 4

normMethod · 0.80
getMethod · 0.45
addMethod · 0.45
cloneMethod · 0.45

Tested by 12

mp_vs_ddp_mainFunction · 0.36
mp_mainFunction · 0.36
baseline_mainFunction · 0.36
test_no_parallelMethod · 0.36
test_mpMethod · 0.36
_run_baseline_train_mainFunction · 0.36
_run_mp_train_mainFunction · 0.36
_run_baseline_trainMethod · 0.36
_run_mp_1gpu_trainMethod · 0.36
_test_state_dictMethod · 0.36