Repository: github.com/hpcaitech/ColossalAI

Class TorchAdamKernel

Defined in tests/test_optimizer/test_adam_kernel.py, lines 43–62.


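```python
import math

from torch import Tensor

# AdamKernel is the base class (not shown on this page); it provides the
# hyperparameters read below: lr, beta1, beta2, eps, weight_decay, use_adamw.


class TorchAdamKernel(AdamKernel):
    def update(self, step: int, param: Tensor, grad: Tensor, exp_avg: Tensor, exp_avg_sq: Tensor):
        # Bias corrections for moments that start at zero; step is expected to be >= 1
        bias_correction1 = 1 - self.beta1**step
        bias_correction2 = 1 - self.beta2**step

        if self.weight_decay != 0:
            if self.use_adamw:
                # Perform stepweight decay (decoupled, AdamW-style): shrink param in place
                param.mul_(1 - self.lr * self.weight_decay)
            else:
                # L2 penalty: add weight_decay * param to the gradient (out of place,
                # so the caller's grad tensor is not modified)
                grad = grad.add(param, alpha=self.weight_decay)

        # Decay the first and second moment running average coefficient
        exp_avg.mul_(self.beta1).add_(grad, alpha=1 - self.beta1)
        exp_avg_sq.mul_(self.beta2).addcmul_(grad, grad, value=1 - self.beta2)
        # eps is added after bias-correcting sqrt(exp_avg_sq)
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(self.eps)

        step_size = self.lr / bias_correction1

        # param <- param - step_size * exp_avg / denom, updated in place
        param.addcdiv_(exp_avg, denom, value=-step_size)
```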
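For reference, the update that `update` performs at step $t$ can be written out as follows. Here $\alpha$ = lr, $\lambda$ = weight_decay, $g_t$ = grad, $m$ = exp_avg and $v$ = exp_avg_sq. This is a transcription of the code on this page, not a formula quoted from ColossalAI's documentation. The last equality rewrites the step in terms of the bias-corrected moments $\hat m_t = m_t/(1-\beta_1^t)$ and $\hat v_t = v_t/(1-\beta_2^t)$.

```latex
\begin{aligned}
&\text{if } \lambda \neq 0:\quad
\begin{cases}
\theta_{t-1} \leftarrow (1 - \alpha\lambda)\,\theta_{t-1} & \text{(use\_adamw: decoupled decay)}\\
g_t \leftarrow g_t + \lambda\,\theta_{t-1} & \text{(otherwise: L2 penalty)}
\end{cases}\\
&m_t = \beta_1 m_{t-1} + (1-\beta_1)\,g_t\\
&v_t = \beta_2 v_{t-1} + (1-\beta_2)\,g_t^2\\
&\theta_t = \theta_{t-1} - \frac{\alpha}{1-\beta_1^t}\cdot\frac{m_t}{\sqrt{v_t}\,/\sqrt{1-\beta_2^t} + \epsilon}
= \theta_{t-1} - \alpha\,\frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon}
\end{aligned}
```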
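The same arithmetic can be traced on plain Python floats, without PyTorch. The helper below, `adam_step_scalar`, is hypothetical and written only for illustration. It is not part of ColossalAI, and its default hyperparameters are arbitrary. It mirrors the tensor code one scalar at a time, including both weight-decay branches, and returns new values instead of updating tensors in place.

```python
import math


def adam_step_scalar(step, param, grad, exp_avg, exp_avg_sq,
                     lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                     weight_decay=0.0, use_adamw=False):
    """Hypothetical scalar mirror of TorchAdamKernel.update.

    Returns the updated (param, exp_avg, exp_avg_sq) instead of mutating tensors.
    """
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step

    if weight_decay != 0:
        if use_adamw:
            # Decoupled (AdamW) decay: shrink the parameter itself
            param = param * (1 - lr * weight_decay)
        else:
            # L2 penalty: fold weight_decay * param into the gradient
            grad = grad + weight_decay * param

    # Running averages of the gradient and of its square
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad

    # Same ordering as the tensor code: bias-correct sqrt(v), then add eps
    denom = math.sqrt(exp_avg_sq) / math.sqrt(bias_correction2) + eps
    step_size = lr / bias_correction1
    param = param - step_size * exp_avg / denom
    return param, exp_avg, exp_avg_sq


# Step 1 from zero moments, with a small and a large gradient
p_small, _, _ = adam_step_scalar(1, 1.0, 0.5, 0.0, 0.0, lr=0.1)
p_large, _, _ = adam_step_scalar(1, 1.0, 50.0, 0.0, 0.0, lr=0.1)

# The same step with weight decay, in both modes
p_adamw, _, _ = adam_step_scalar(1, 1.0, 0.5, 0.0, 0.0, lr=0.1, weight_decay=0.1, use_adamw=True)
p_l2, _, _ = adam_step_scalar(1, 1.0, 0.5, 0.0, 0.0, lr=0.1, weight_decay=0.1, use_adamw=False)
print(p_small, p_large, p_adamw, p_l2)
```

On step 1 the bias corrections cancel the (1 - beta1) and (1 - beta2) factors. As a result, the parameter moves by almost exactly lr whatever the gradient's scale: p_small and p_large both land at about 0.9. With weight decay switched on, the AdamW branch still shrinks the parameter by lr * weight_decay before the Adam step, so p_adamw lands at about 0.89. In this example the L2 branch's extra gradient term is normalized away on that first step, and p_l2 stays at about 0.9.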

Callers (1)

check_adam_kernel (function) · 0.85

Calls

no outgoing calls

Tested by (1)

check_adam_kernel (function) · 0.68
