MCPcopy Index your code
hub / github.com/tinygrad/tinygrad / LAMB

Class LAMB

tinygrad/nn/optim.py:147–179  ·  view source on GitHub ↗

LAMB optimizer with optional weight decay. Paper: https://arxiv.org/abs/1904.00962

Source from the content-addressed store, hash-verified

145 return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, device=device, fused=fused)
146
class LAMB(Optimizer):
  """
  LAMB optimizer with optional weight decay.

  Every step keeps Adam-style first/second moment estimates and forms the bias-corrected
  adaptive update m_hat / (sqrt(v_hat) + eps). It adds `weight_decay * param` to that update
  and, unless `adam=True`, rescales the result by the layer-wise trust ratio
  ||param|| / ||update||. With `adam=True` the trust ratio is fixed at 1.0, which yields a
  plain Adam-style update (the Adam/AdamW constructors build the optimizer this way).

  - Paper: https://arxiv.org/abs/1904.00962
  """
  def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, device=None, fused=FUSE_OPTIM):
    """
    Args:
      params: tensors to optimize; forwarded to Optimizer.
      lr: learning rate; forwarded to Optimizer. Each step multiplies every update by self.lr.
      b1, b2: decay rates of the first (m) and second (v) moment estimates.
      eps: added to sqrt(v_hat) in the denominator of the adaptive update.
      weight_decay: coefficient of the `weight_decay * param` term added to the update. Must be >= 0.
      adam: if True, skip the LAMB trust ratio and use 1.0 instead.
      device, fused: forwarded to Optimizer unchanged.

    Raises:
      ValueError: if weight_decay is negative.
    """
    if weight_decay < 0:
      raise ValueError(f"Invalid weight_decay value: {weight_decay}")
    super().__init__(params, lr, device, fused)
    self.b1 = b1
    self.b2 = b2
    self.eps = eps
    self.wd = weight_decay
    self.adam = adam

    def _unit_scalar() -> Tensor:
      # one-element float32 tensor holding 1.0 on the optimizer's device
      return Tensor.ones((1,), dtype=dtypes.float32, device=self.device).is_param_(False)

    # Running powers b1**t and b2**t used for bias correction.
    # They start at 1 and _step multiplies each by its beta once per call.
    self.b1_t, self.b2_t = _unit_scalar(), _unit_scalar()
    # first (m) and second (v) moment buffers, allocated by the base-class helper
    self.m = self._new_optim_param()
    self.v = self._new_optim_param()

  @staticmethod
  def _trust_ratio(weight: Tensor, update: Tensor) -> Tensor:
    """Return ||weight|| / ||update|| (L2 norms), or 1.0 when either norm is not positive."""
    weight_norm = weight.square().sum().sqrt()
    update_norm = update.square().sum().sqrt()
    return Tensor.where(weight_norm > 0, Tensor.where(update_norm > 0, weight_norm / update_norm, 1.0), 1.0)

  def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
    """
    Compute one LAMB update for each (param, grad) pair.

    Returns:
      A pair (updates, state).
      - updates[k] is `lr * trust_ratio * update` for params[k], cast to that param's dtype.
      - state is [b1_t, b2_t, *m, *v], i.e. the optimizer tensors assigned during this call.

    NOTE(review): applying the updates (presumably `param - update`) and realizing `state`
    appear to be the Optimizer base class's job. Confirm there.
    """
    # Advance the bias-correction powers before use, so that step t divides by (1 - b**t).
    self.b1_t *= self.b1
    self.b2_t *= self.b2
    updates: list[Tensor] = []
    for i, (t, g) in enumerate(zip(params, grads)):
      m, v = self.m[i], self.v[i]
      # The moment buffers may live on another device (e.g. CPU-offloaded state),
      # so bring the grad to them.
      if g.device != m.device: g = g.contiguous().to(m.device)
      m.assign((self.b1 * m + (1.0 - self.b1) * g).cast(m.dtype))
      v.assign((self.b2 * v + (1.0 - self.b2) * (g * g)).cast(v.dtype))
      m_hat = m / (1.0 - self.b1_t)
      v_hat = v / (1.0 - self.b2_t)
      # Weight decay is added to the normalized update, not to the gradient.
      # shard_like(t) presumably lays the adaptive term out like t first -- confirm.
      update = (m_hat / (v_hat.sqrt() + self.eps)).shard_like(t) + self.wd * t.detach()
      trust: Tensor|float = 1.0 if self.adam else self._trust_ratio(t.detach(), update)
      updates.append((self.lr * trust * update).cast(t.dtype))
    return updates, [self.b1_t, self.b2_t, *self.m, *self.v]

Callers 6

test_lamb_cpu_offload · Method · 0.90
_test_layer · Method · 0.90
train_bert · Function · 0.90
AdamW · Function · 0.85
Adam · Function · 0.85

Calls

no outgoing calls

Tested by 2

test_lamb_cpu_offload · Method · 0.72

Used in the wild: real call sites across dependent graphs

searching dependent graphs…