(x, y)
| 8 | transform=lambda x: x, target_transform=lambda x: x, noloss=False, allow_jit=True): |
| 9 | |
| 10 | def train_step(x, y): |
| 11 | # network |
| 12 | out = model.forward(x) if hasattr(model, 'forward') else model(x) |
| 13 | loss = lossfn(out, y) |
| 14 | optim.zero_grad() |
| 15 | loss.backward() |
| 16 | if noloss: del loss |
| 17 | optim.step() |
| 18 | if noloss: return (None, None) |
| 19 | cat = out.argmax(axis=-1) |
| 20 | accuracy = (cat == y).mean() |
| 21 | return loss.realize(), accuracy.realize() |
| 22 | |
| 23 | if allow_jit: train_step = TinyJit(train_step) |
| 24 |
no test coverage detected
searching dependent graphs…