LAMB optimizer with optional weight decay. - Paper: https://arxiv.org/abs/1904.00962
| 145 | return LAMB(params, lr, b1, b2, eps, 0.0, adam=True, device=device, fused=fused) |
| 146 | |
class LAMB(Optimizer):
  """
  LAMB optimizer with optional weight decay.

  Tracks running first (`m`) and second (`v`) moment estimates of every gradient, bias-corrects them, and
  rescales each tensor's update by a trust ratio (norm of the parameter divided by norm of the update).
  When `adam=True` the trust ratio is fixed at 1.0.

  - Paper: https://arxiv.org/abs/1904.00962
  """
  def __init__(self, params: list[Tensor], lr=0.001, b1=0.9, b2=0.999, eps=1e-6, weight_decay=0.0, adam=False, device=None, fused=FUSE_OPTIM):
    # a negative decay is rejected up front; raises ValueError
    if weight_decay < 0:
      raise ValueError(f"Invalid weight_decay value: {weight_decay}")
    super().__init__(params, lr, device, fused)
    self.b1 = b1
    self.b2 = b2
    self.eps = eps
    self.wd = weight_decay
    self.adam = adam
    # running products of b1 and b2; they start at 1 and are multiplied by b1/b2 once per step
    # NOTE(review): is_param_(False) presumably excludes these from the trainable parameters -- confirm
    self.b1_t = Tensor.ones((1,), dtype=dtypes.float32, device=self.device).is_param_(False)
    self.b2_t = Tensor.ones((1,), dtype=dtypes.float32, device=self.device).is_param_(False)
    # moment buffers, indexed in step with `params` in _step
    # NOTE(review): presumably one zero-initialised tensor per parameter -- confirm in _new_optim_param
    self.m = self._new_optim_param()
    self.v = self._new_optim_param()

  def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
    # Returns (one scaled update per parameter, cast to that parameter's dtype; every optimizer state tensor).
    # advance the bias-correction products to b1**step and b2**step
    self.b1_t *= self.b1
    self.b2_t *= self.b2
    updates = []
    for idx, (param, grad) in enumerate(zip(params, grads)):
      first_moment, second_moment = self.m[idx], self.v[idx]
      # bring the gradient onto the device that holds this parameter's optimizer state
      if grad.device != first_moment.device:
        grad = grad.contiguous().to(first_moment.device)
      # exponential moving averages of the gradient and of its square, updated in place
      first_moment.assign((self.b1 * first_moment + (1.0 - self.b1) * grad).cast(first_moment.dtype))
      second_moment.assign((self.b2 * second_moment + (1.0 - self.b2) * (grad * grad)).cast(second_moment.dtype))
      # bias-corrected moment estimates
      corrected_first = first_moment / (1.0 - self.b1_t)
      corrected_second = second_moment / (1.0 - self.b2_t)
      # Adam-style direction plus weight decay applied to the parameter value itself
      # NOTE(review): shard_like(param) presumably matches the parameter's sharding layout -- confirm
      update = (corrected_first / (corrected_second.sqrt() + self.eps)).shard_like(param) + self.wd * param.detach()
      trust: Tensor|float
      if self.adam:
        trust = 1.0
      else:
        weight_norm = param.detach().square().sum().sqrt()
        update_norm = update.square().sum().sqrt()
        # the ratio falls back to 1.0 whenever either norm is zero
        trust = (weight_norm > 0).where((update_norm > 0).where(weight_norm / update_norm, 1.0), 1.0)
      updates.append((self.lr * trust * update).cast(param.dtype))
    return updates, [self.b1_t, self.b2_t, *self.m, *self.v]
no outgoing calls
searching dependent graphs…