hub / github.com/tinygrad/tinygrad / train

Function train

extra/training.py:7–38
(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
        transform=lambda x: x, target_transform=lambda x: x, noloss=False, allow_jit=True)
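
The three callable parameters all have defaults: `lossfn` applies `sparse_categorical_crossentropy` to the model output, while `transform` and `target_transform` are identity functions applied to each sampled batch before it is wrapped in a `Tensor`. A minimal sketch of what a caller might pass for the two transforms, using NumPy only and assuming `X_train` and `Y_train` are NumPy arrays. The array shapes, the 28x28 reshape, and the names `to_nchw` and `to_int32` are illustrative assumptions, not taken from the repository:

```python
import numpy as np

# Hypothetical data: 1000 flattened 28x28 images with labels 0..9.
rng = np.random.default_rng(0)
X_train = rng.random((1000, 784), dtype=np.float32)
Y_train = rng.integers(0, 10, size=1000)

def to_nchw(x):
    # Hypothetical transform: reshape a flat batch to (batch, channel, height, width).
    return x.reshape(-1, 1, 28, 28)

def to_int32(y):
    # Hypothetical target_transform: cast labels to a fixed integer dtype.
    return y.astype(np.int32)

# The same indexing pattern train uses to draw one batch of BS rows.
BS = 128
samp = np.random.randint(0, X_train.shape[0], size=(BS))
x_batch = to_nchw(X_train[samp])
y_batch = to_int32(Y_train[samp])
print(x_batch.shape, y_batch.dtype)
```

Under these assumptions the functions would be passed as `train(..., transform=to_nchw, target_transform=to_int32)`.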


# Imports assumed: the listing starts at line 5 of the file, so the original import lines are not shown.
import numpy as np
from tqdm import trange
from tinygrad import Tensor, TinyJit

def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
        transform=lambda x: x, target_transform=lambda x: x, noloss=False, allow_jit=True):

  def train_step(x, y):
    # network: use model.forward if it exists, otherwise call the model directly
    out = model.forward(x) if hasattr(model, 'forward') else model(x)
    loss = lossfn(out, y)
    optim.zero_grad()
    loss.backward()
    if noloss: del loss
    optim.step()
    if noloss: return (None, None)
    # accuracy is measured on the outputs computed before this step's update
    cat = out.argmax(axis=-1)
    accuracy = (cat == y).mean()
    return loss.realize(), accuracy.realize()

  # optionally wrap the step in TinyJit
  if allow_jit: train_step = TinyJit(train_step)

  with Tensor.train():
    losses, accuracies = [], []
    for i in (t := trange(steps, disable=None)):
      # draw BS random indices (with replacement) for this batch
      samp = np.random.randint(0, X_train.shape[0], size=(BS))
      x = Tensor(transform(X_train[samp]))
      y = Tensor(target_transform(Y_train[samp]))
      loss, accuracy = train_step(x, y)
      # printing
      if not noloss:
        loss, accuracy = loss.numpy(), accuracy.numpy()
        losses.append(loss)
        accuracies.append(accuracy)
        t.set_description("loss %.2f accuracy %.2f" % (loss, accuracy))
  return [losses, accuracies]


def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False, transform=lambda x: x,
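
The control flow of `train` can be mirrored in plain NumPy to make the batch sampling, the forward dispatch, and the accuracy computation concrete. This is an illustrative sketch, not tinygrad code: `SoftmaxRegression`, `cross_entropy_and_grad`, `train_numpy`, the learning rate, and the synthetic data are all assumptions, and a hand-written SGD update stands in for `optim.zero_grad()`, `loss.backward()`, and `optim.step()`.

```python
import numpy as np

class SoftmaxRegression:
    """Hypothetical stand-in model: one linear layer, callable, with no .forward method."""
    def __init__(self, n_in, n_out):
        self.W = np.zeros((n_in, n_out))

    def __call__(self, x):
        return x @ self.W

def cross_entropy_and_grad(model, x, y):
    # Same dispatch as train_step: prefer model.forward, else call the model.
    out = model.forward(x) if hasattr(model, 'forward') else model(x)
    shifted = out - out.max(axis=1, keepdims=True)  # numerical stability
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    n = x.shape[0]
    loss = -np.log(probs[np.arange(n), y]).mean()
    # Gradient of the mean cross-entropy with respect to the logits, then W.
    dlogits = probs.copy()
    dlogits[np.arange(n), y] -= 1.0
    grad_W = x.T @ dlogits / n
    return out, loss, grad_W

def train_numpy(model, X_train, Y_train, steps, BS=128, lr=0.1):
    losses, accuracies = [], []
    for _ in range(steps):
        # Same sampling as train: BS indices drawn with replacement.
        samp = np.random.randint(0, X_train.shape[0], size=(BS))
        x, y = X_train[samp], Y_train[samp]
        out, loss, grad_W = cross_entropy_and_grad(model, x, y)
        model.W -= lr * grad_W  # plain SGD in place of optim.step()
        # Accuracy uses the outputs computed before the update, as in train_step.
        accuracy = (out.argmax(axis=-1) == y).mean()
        losses.append(loss)
        accuracies.append(accuracy)
    return [losses, accuracies]

# Synthetic, well-separated two-class data (an assumption for the demo).
np.random.seed(0)
X = np.concatenate([np.random.randn(200, 2) + 3.0, np.random.randn(200, 2) - 3.0])
Y = np.concatenate([np.zeros(200, dtype=np.int64), np.ones(200, dtype=np.int64)])
losses, accuracies = train_numpy(SoftmaxRegression(2, 2), X, Y, steps=50, BS=32)
print(len(losses), len(accuracies))
```

Because indices are drawn with replacement, `steps` counts batches rather than epochs, and the same example can appear more than once in a batch.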

Callers 13

lr_scheduler_training · Function · 0.90
test_sgd_onestep · Method · 0.90
test_sgd_threestep · Method · 0.90
test_sgd_sixstep · Method · 0.90
test_adam_onestep · Method · 0.90
test_adam_threestep · Method · 0.90
test_conv_onestep · Method · 0.90
test_conv · Method · 0.90
test_conv_with_bn · Method · 0.90
test_sgd · Method · 0.90
train_one_step · Function · 0.90
train_resnet.py · File · 0.90

Calls 10

TinyJit · Class · 0.90
trange · Function · 0.90
Tensor · Class · 0.90
train · Method · 0.80
randint · Method · 0.80
append · Method · 0.80
set_description · Method · 0.80
train_step · Function · 0.70
numpy · Method · 0.45

Tested by 11

lr_scheduler_training · Function · 0.72
test_sgd_onestep · Method · 0.72
test_sgd_threestep · Method · 0.72
test_sgd_sixstep · Method · 0.72
test_adam_onestep · Method · 0.72
test_adam_threestep · Method · 0.72
test_conv_onestep · Method · 0.72
test_conv · Method · 0.72
test_conv_with_bn · Method · 0.72
test_sgd · Method · 0.72
train_one_step · Function · 0.72
