MCPcopy Index your code
hub / github.com/geekcomputers/Python / AdamW

Class AdamW

ML/src/python/neuralforge/optim/optimizers.py:5–67  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

3import math
4
5class AdamW(Optimizer):
6 def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01, amsgrad=False):
7 if lr < 0.0:
8 raise ValueError(f"Invalid learning rate: {lr}")
9 if eps < 0.0:
10 raise ValueError(f"Invalid epsilon value: {eps}")
11 if not 0.0 <= betas[0] < 1.0:
12 raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
13 if not 0.0 <= betas[1] < 1.0:
14 raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
15
16 defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
17 super().__init__(params, defaults)
18
19 def step(self, closure=None):
20 loss = None
21 if closure is not None:
22 loss = closure()
23
24 for group in self.param_groups:
25 for p in group['params']:
26 if p.grad is None:
27 continue
28
29 grad = p.grad.data
30 if grad.is_sparse:
31 raise RuntimeError('AdamW does not support sparse gradients')
32
33 amsgrad = group['amsgrad']
34 state = self.state[p]
35
36 if len(state) == 0:
37 state['step'] = 0
38 state['exp_avg'] = torch.zeros_like(p.data)
39 state['exp_avg_sq'] = torch.zeros_like(p.data)
40 if amsgrad:
41 state['max_exp_avg_sq'] = torch.zeros_like(p.data)
42
43 exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
44 if amsgrad:
45 max_exp_avg_sq = state['max_exp_avg_sq']
46 beta1, beta2 = group['betas']
47
48 state['step'] += 1
49
50 p.data.mul_(1 - group['lr'] * group['weight_decay'])
51
52 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
53 exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
54
55 if amsgrad:
56 torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
57 denom = max_exp_avg_sq.sqrt().add_(group['eps'])
58 else:
59 denom = exp_avg_sq.sqrt().add_(group['eps'])
60
61 bias_correction1 = 1 - beta1 ** state['step']
62 bias_correction2 = 1 - beta2 ** state['step']

Callers 3

mainFunction · 0.90
mainFunction · 0.90
mainFunction · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected