| 159 | self.v = self._new_optim_param() |
| 160 | |
def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
    """Compute one update tensor per parameter and refresh the optimizer state.

    ``self.b1_t`` and ``self.b2_t`` are multiplied in place by ``self.b1`` and
    ``self.b2`` once per call. For every (parameter, gradient) pair the
    matching entries of ``self.m`` and ``self.v`` are reassigned to the
    exponential moving averages of the gradient and of the squared gradient.
    The bias-corrected ratio ``m_hat / (sqrt(v_hat) + eps)`` plus
    ``self.wd * param`` forms the raw step. When ``self.adam`` is falsy the
    step is additionally scaled by ``||param|| / ||step||``, falling back to
    1.0 whenever either norm is not positive.

    Args:
        params: Parameter tensors, consumed in lockstep with ``grads``.
        grads: Gradient tensors, one per entry of ``params``.

    Returns:
        A pair ``(updates, state)``: ``updates[i]`` is
        ``self.lr * scale * step`` cast to the dtype of ``params[i]``, and
        ``state`` is ``[b1_t, b2_t]`` followed by every tensor in ``self.m``
        and then every tensor in ``self.v``.
    """
    # Advance the running powers of the decay rates (in-place on the state).
    self.b1_t *= self.b1
    self.b2_t *= self.b2

    updates = []
    for idx, (param, grad) in enumerate(zip(params, grads)):
        first, second = self.m[idx], self.v[idx]

        # Bring the gradient onto the device that holds the moment buffers.
        if grad.device != first.device:
            grad = grad.contiguous().to(first.device)

        # Moving averages are written back into the existing state tensors,
        # keeping each buffer's own dtype.
        first.assign((self.b1 * first + (1.0 - self.b1) * grad).cast(first.dtype))
        second.assign((self.b2 * second + (1.0 - self.b2) * (grad * grad)).cast(second.dtype))

        m_hat = first / (1.0 - self.b1_t)
        v_hat = second / (1.0 - self.b2_t)
        step = (m_hat / (v_hat.sqrt() + self.eps)).shard_like(param) + self.wd * param.detach()

        if self.adam:
            scale: Tensor|float = 1.0
        else:
            # Ratio of the parameter norm to the step norm; 1.0 if either is zero.
            w_norm = param.detach().square().sum().sqrt()
            s_norm = step.square().sum().sqrt()
            scale = Tensor.where(w_norm > 0, Tensor.where(s_norm > 0, w_norm / s_norm, 1.0), 1.0)

        updates.append((self.lr * scale * step).cast(param.dtype))

    return updates, [self.b1_t, self.b2_t, *self.m, *self.v]