(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
transform=lambda x: x, target_transform=lambda x: x, noloss=False, allow_jit=True)
| 5 | |
| 6 | |
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),
          transform=lambda x: x, target_transform=lambda x: x, noloss=False, allow_jit=True):
    """Run ``steps`` optimisation steps of ``model`` on random minibatches.

    Each step draws ``BS`` row indices with ``np.random.randint`` (so rows may
    repeat within a batch), applies ``transform`` / ``target_transform`` to the
    selected inputs / targets, wraps both in ``Tensor`` and performs one
    forward/backward/optimizer update.

    Args:
        model: Object exposing ``forward(x)``; if it has no ``forward``
            attribute it is called directly as ``model(x)``.
        X_train: Training inputs, indexed by an integer index array along the
            first axis (``X_train.shape[0]`` is the number of samples).
        Y_train: Training targets, indexed with the same index array.
        optim: Optimizer providing ``zero_grad()`` and ``step()``.
        steps: Number of training iterations.
        BS: Minibatch size.
        lossfn: Callable ``(out, y) -> loss``; the loss must support
            ``backward()`` and ``realize()``.
        transform: Applied to the sampled inputs before building the Tensor.
        target_transform: Applied to the sampled targets before building the
            Tensor.
        noloss: When true, no loss/accuracy is computed or recorded and the
            progress description is not updated.
        allow_jit: When true, the per-step function is wrapped in ``TinyJit``.

    Returns:
        ``[losses, accuracies]`` -- two lists with one ``.numpy()`` value per
        step; both are empty when ``noloss`` is true.
    """

    def train_step(x, y):
        # Forward pass: prefer an explicit ``forward`` method when available.
        if hasattr(model, 'forward'):
            out = model.forward(x)
        else:
            out = model(x)
        loss = lossfn(out, y)

        # Backward pass and parameter update.
        optim.zero_grad()
        loss.backward()
        if noloss:
            # Drop the loss reference before stepping; nothing is reported.
            del loss
            optim.step()
            return (None, None)
        optim.step()

        # Fraction of samples whose arg-max over the last axis equals the target.
        predicted = out.argmax(axis=-1)
        accuracy = (predicted == y).mean()
        return loss.realize(), accuracy.realize()

    if allow_jit:
        train_step = TinyJit(train_step)

    with Tensor.train():
        losses, accuracies = [], []
        progress = trange(steps, disable=None)
        for _ in progress:
            # Sample a minibatch of row indices (with replacement).
            samp = np.random.randint(0, X_train.shape[0], size=BS)
            batch_x = Tensor(transform(X_train[samp]))
            batch_y = Tensor(target_transform(Y_train[samp]))
            step_loss, step_accuracy = train_step(batch_x, batch_y)

            if noloss:
                continue

            # Pull the metrics back as numpy values for logging/reporting.
            loss_value = step_loss.numpy()
            accuracy_value = step_accuracy.numpy()
            losses.append(loss_value)
            accuracies.append(accuracy_value)
            progress.set_description("loss %.2f accuracy %.2f" % (loss_value, accuracy_value))
    return [losses, accuracies]
| 39 | |
| 40 | |
| 41 | def evaluate(model, X_test, Y_test, num_classes=None, BS=128, return_predict=False, transform=lambda x: x, |
searching dependent graphs…