(trainer)
# iteration callback
top_score = 0  # best combined (train + test) score seen so far


def batch_end_callback(trainer):
    """Log training progress and periodically evaluate and checkpoint the model.

    Every 10 iterations the iteration time and the current training loss are
    printed. Every 500 iterations the model is scored on the 'train' and
    'test' splits via ``eval_split``; when the summed score exceeds the best
    one seen so far, ``model.state_dict()`` is saved to ``model.pt`` inside
    ``config.system.work_dir``.

    Args:
        trainer: object exposing ``iter_num``, ``iter_dt`` and ``loss``; it is
            also forwarded unchanged to ``eval_split``.
    """
    global top_score

    if trainer.iter_num % 10 == 0:
        # iter_dt is scaled by 1000 for display in ms (presumably seconds -- confirm in the trainer)
        print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")

    if trainer.iter_num % 500 != 0:
        return

    # evaluate both the train and test score
    # For ndigit <= 2 the whole train set is affordable; anything larger is
    # capped at 5 batches. A threshold (rather than a fixed 1..3 lookup table)
    # keeps this working for ndigit > 3 instead of raising KeyError.
    # NOTE(review): max_batches=None presumably means "no cap" -- confirm in eval_split.
    train_max_batches = None if config.data.ndigit <= 2 else 5

    model.eval()
    try:
        with torch.no_grad():
            train_score = eval_split(trainer, 'train', max_batches=train_max_batches)
            test_score = eval_split(trainer, 'test', max_batches=None)
        score = train_score + test_score
        # save the model if this is the best score we've seen so far
        if score > top_score:
            top_score = score
            print(f"saving model with new top score of {score}")
            ckpt_path = os.path.join(config.system.work_dir, "model.pt")
            torch.save(model.state_dict(), ckpt_path)
    finally:
        # revert model to training mode, even if evaluation or saving failed
        model.train()


trainer.set_callback('on_batch_end', batch_end_callback)
| 205 |
nothing calls this directly
no test coverage detected