Repository: github.com/tinygrad/tinygrad

Function train_step(x, y)

Defined in extra/training.py, lines 10–21.

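```python
        # ...tail of the enclosing function's signature (earlier lines not shown)
        transform=lambda x: x, target_transform=lambda x: x, noloss=False, allow_jit=True):

  def train_step(x, y):
    # network: call model.forward if the model defines it, otherwise call the model itself
    out = model.forward(x) if hasattr(model, 'forward') else model(x)
    loss = lossfn(out, y)
    # clear stale gradients, backpropagate, then update the parameters
    optim.zero_grad()
    loss.backward()
    if noloss: del loss
    optim.step()
    # noloss=True: the loss is dropped early and no metrics are returned
    if noloss: return (None, None)
    cat = out.argmax(axis=-1)      # predicted class per sample
    accuracy = (cat == y).mean()   # fraction of correct predictions in the batch
    # realize() forces the lazy tensors to be computed before they are returned
    return loss.realize(), accuracy.realize()

  # optionally wrap the step in TinyJit
  if allow_jit: train_step = TinyJit(train_step)
```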
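
The body runs the forward pass, the loss, `zero_grad`, `backward`, and `step` in that order. The same ordering can be illustrated in plain Python with a toy model. This is a minimal sketch, assuming a one-weight linear regression trained with mean-squared error and a hand-derived gradient in place of autograd; `Param`, `SGD`, `Linear`, and `make_train_step` are hypothetical stand-ins, not tinygrad classes, and accuracy is left out because the toy task is regression.

```python
from typing import Callable, List, Optional


class Param:
    """Scalar parameter with an accumulated gradient (hypothetical stand-in for a tensor)."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


class SGD:
    """Plain gradient descent over a list of Params (hypothetical stand-in for an optimizer)."""

    def __init__(self, params: List[Param], lr: float) -> None:
        self.params = params
        self.lr = lr

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = 0.0

    def step(self) -> None:
        for p in self.params:
            p.value -= self.lr * p.grad


class Linear:
    """Hypothetical one-weight model: prediction = w * x."""

    def __init__(self) -> None:
        self.w = Param(0.0)

    def forward(self, xs: List[float]) -> List[float]:
        return [self.w.value * x for x in xs]


def make_train_step(model, optim: SGD, noloss: bool = False) -> Callable[[List[float], List[float]], Optional[float]]:
    def train_step(xs: List[float], ys: List[float]) -> Optional[float]:
        # same dispatch as the original: prefer model.forward, else call the model
        out = model.forward(xs) if hasattr(model, "forward") else model(xs)
        n = len(xs)
        loss = sum((o - y) ** 2 for o, y in zip(out, ys)) / n  # mean squared error
        optim.zero_grad()  # grad is accumulated below, so clear the previous step's value first
        # hand-written "backward": d(loss)/dw = (2/n) * sum((w*x - y) * x)
        model.w.grad += 2.0 / n * sum((o - y) * x for o, y, x in zip(out, ys, xs))
        optim.step()  # move w against the gradient
        if noloss:
            return None  # caller asked for no metrics
        return loss

    return train_step


model = Linear()
train_step = make_train_step(model, SGD([model.w], lr=0.1))
xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]  # data generated by y = 2x
losses = [train_step(xs, ys) for _ in range(50)]
print(losses[0], model.w.value)  # loss starts at 56/3; w ends close to 2.0
```

Clearing gradients before the backward pass matters here because the sketch adds into `grad` rather than overwriting it; without `zero_grad`, each update would also include the previous step's gradient.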

Callers (1)

train (function) · 0.70

Calls (8)

model (function) · 0.85
argmax (method) · 0.80
mean (method) · 0.80
realize (method) · 0.80
forward (method) · 0.45
zero_grad (method) · 0.45
backward (method) · 0.45
step (method) · 0.45
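
The `argmax` and `mean` calls listed above produce the batch accuracy: the predicted class is the index of the largest score along the last axis, and accuracy is the fraction of predictions equal to the label. A plain-Python sketch of that arithmetic, assuming `out` is a list of per-class score rows and `y` a list of integer labels; `argmax_last` and `batch_accuracy` are hypothetical helpers, and this sketch breaks ties toward the first index.

```python
from typing import List, Sequence


def argmax_last(row: Sequence[float]) -> int:
    """Index of the largest score in one row; ties go to the first index."""
    best = 0
    for i in range(1, len(row)):
        if row[i] > row[best]:
            best = i
    return best


def batch_accuracy(out: List[Sequence[float]], y: Sequence[int]) -> float:
    """Plain-Python analogue of (out.argmax(axis=-1) == y).mean() for a 2-D batch."""
    cat = [argmax_last(row) for row in out]                  # predicted class per sample
    hits = [1.0 if c == t else 0.0 for c, t in zip(cat, y)]  # elementwise cat == y
    return sum(hits) / len(hits)                             # mean over the batch


scores = [[0.1, 0.9], [2.0, -1.0], [0.5, 0.3]]  # 3 samples, 2 classes
labels = [1, 0, 1]
print(batch_accuracy(scores, labels))  # predictions [1, 0, 0], so 2 of 3 are correct
```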

Tested by

No test coverage detected.
