Repository: github.com/tinygrad/tinygrad

Method _step

tinygrad/nn/optim.py:161–179
(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]


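```python
    # Last line of __init__, shown for context: the second-moment buffers v are allocated here.
    self.v = self._new_optim_param()

  def _step(self, params:list[Tensor], grads:list[Tensor]) -> tuple[list[Tensor], list[Tensor]]:
    ret = []
    # Running powers beta1^t and beta2^t, used for bias correction below.
    self.b1_t *= self.b1
    self.b2_t *= self.b2
    for i, (t, g) in enumerate(zip(params, grads)):
      # Move the gradient to the device that holds the optimizer state, if they differ.
      if g.device != self.m[i].device: g = g.contiguous().to(self.m[i].device)
      # Exponential moving averages of the gradient and of the squared gradient.
      self.m[i].assign((self.b1 * self.m[i] + (1.0 - self.b1) * g).cast(self.m[i].dtype))
      self.v[i].assign((self.b2 * self.v[i] + (1.0 - self.b2) * (g * g)).cast(self.v[i].dtype))
      # Bias-corrected moments.
      m_hat = self.m[i] / (1.0 - self.b1_t)
      v_hat = self.v[i] / (1.0 - self.b2_t)
      # Adam direction, matched to the parameter's sharding, plus weight decay.
      up = (m_hat / (v_hat.sqrt() + self.eps)).shard_like(t) + self.wd * t.detach()
      if not self.adam:
        # LAMB trust ratio ||t|| / ||up||, falling back to 1.0 when either norm is zero.
        r1 = t.detach().square().sum().sqrt()
        r2 = up.square().sum().sqrt()
        r: Tensor|float = Tensor.where(r1 > 0, Tensor.where(r2 > 0, r1 / r2, 1.0), 1.0)
      else:
        r = 1.0
      ret.append((self.lr * r * up).cast(t.dtype))
    # Scaled updates, plus the optimizer state tensors (beta powers and both moments).
    return ret, [self.b1_t, self.b2_t] + self.m + self.v
```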
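Written out for one parameter tensor θ with gradient g, the loop body computes the following. This is a summary read off the code, with η = `self.lr`, λ = `self.wd`, ε = `self.eps`; it assumes `b1_t` and `b2_t` are initialized to 1 outside this method, so that after the t-th call they hold β₁ᵗ and β₂ᵗ. Squares, square roots, and divisions are elementwise; the norms are taken over all elements of the tensor.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t \\
v_t &= \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat m_t &= \frac{m_t}{1-\beta_1^t}, \qquad \hat v_t = \frac{v_t}{1-\beta_2^t} \\
u_t &= \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\, \theta_{t-1} \\
r_t &= \begin{cases}
  \dfrac{\lVert \theta_{t-1} \rVert_2}{\lVert u_t \rVert_2} & \text{if not adam, } \lVert \theta_{t-1} \rVert_2 > 0,\ \lVert u_t \rVert_2 > 0 \\
  1 & \text{otherwise}
\end{cases} \\
\Delta_t &= \eta\, r_t\, u_t
\end{aligned}
```

Δₜ is what `_step` appends to `ret`; applying it to the parameter is not part of this method and presumably happens in the caller. With `adam=True` the ratio is pinned to 1, so the step is the plain Adam direction plus λθ, i.e. decoupled (AdamW-style) weight decay.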

Callers

No direct callers detected.

Calls (11)

shard_like · method · 0.80
sqrt · method · 0.80
detach · method · 0.80
sum · method · 0.80
square · method · 0.80
append · method · 0.80
to · method · 0.45
contiguous · method · 0.45
assign · method · 0.45
cast · method · 0.45
where · method · 0.45

Tested by

No test coverage detected.
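Since no tests were detected, a pure-Python reference for the same arithmetic could serve as an oracle for one. The sketch below is a minimal version that works on flat lists of floats. It drops the device moves, dtype casts, and sharding. The function name `lamb_step`, the `state` dict standing in for `b1_t`/`b2_t`, and the default hyperparameters are all hypothetical, not tinygrad's API.

```python
import math

def lamb_step(params: list[list[float]], grads: list[list[float]],
              m: list[list[float]], v: list[list[float]], state: dict,
              lr=1e-3, b1=0.9, b2=0.999, eps=1e-6, wd=0.0, adam=False) -> list[list[float]]:
    """Hypothetical reference mirroring _step: returns per-parameter updates.

    m, v and state ("b1_t", "b2_t", both starting at 1.0) are updated in place.
    Defaults are assumptions for illustration only.
    """
    # Running powers beta^t for bias correction.
    state["b1_t"] *= b1
    state["b2_t"] *= b2
    updates = []
    for i, (t, g) in enumerate(zip(params, grads)):
        # Moving averages of the gradient and squared gradient.
        m[i] = [b1 * mi + (1.0 - b1) * gi for mi, gi in zip(m[i], g)]
        v[i] = [b2 * vi + (1.0 - b2) * gi * gi for vi, gi in zip(v[i], g)]
        up = []
        for mi, vi, ti in zip(m[i], v[i], t):
            m_hat = mi / (1.0 - state["b1_t"])
            v_hat = vi / (1.0 - state["b2_t"])
            # Adam direction plus (decoupled) weight decay on the current parameter.
            up.append(m_hat / (math.sqrt(v_hat) + eps) + wd * ti)
        if adam:
            r = 1.0
        else:
            # Trust ratio over the whole tensor, 1.0 if either norm is zero.
            r1 = math.sqrt(sum(x * x for x in t))
            r2 = math.sqrt(sum(x * x for x in up))
            r = r1 / r2 if r1 > 0 and r2 > 0 else 1.0
        updates.append([lr * r * u for u in up])
    return updates
```

A property worth asserting in a real test: in LAMB mode with both norms nonzero, the update norm is η·r·‖u‖ = η·‖θ‖, independent of the gradient scale. The zero-norm fallbacks of the nested `Tensor.where` are the other obvious edge cases to cover.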