| 486 | |
@torch.inference_mode()
def evaluate(model, dataset, batch_size=50, max_batches=None, device=None):
    """Return the mean per-batch loss of ``model`` over (a sample of) ``dataset``.

    The model is put in eval mode for the duration of the call. Its previous
    train/eval mode is restored afterwards, even if the forward pass raises.

    Args:
        model: module whose ``forward(X, Y)`` returns a ``(logits, loss)``
            pair, where ``loss`` is a single-element tensor.
        dataset: dataset whose batches collate into an ``(X, Y)`` pair.
        batch_size: number of examples per batch.
        max_batches: if not None, evaluate at most this many batches. Batches
            are drawn in shuffled order, so this yields a random subsample.
        device: device each batch is moved to. When None, falls back to the
            module-level ``args.device`` (presumably the parsed CLI arguments
            defined elsewhere in this script).

    Returns:
        The mean of the per-batch losses as a Python float, or NaN if no
        batch was evaluated (e.g. ``max_batches=0`` or an empty dataset).
    """
    if device is None:
        device = args.device  # preserve the original global-config behaviour
    was_training = model.training
    model.eval()
    losses = []
    try:
        loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=0)
        for i, (X, Y) in enumerate(loader):
            # Check before the forward pass so exactly `max_batches` batches run
            # (checking afterwards evaluated one extra batch).
            if max_batches is not None and i >= max_batches:
                break
            X, Y = X.to(device), Y.to(device)
            _, loss = model(X, Y)
            losses.append(loss.item())
    finally:
        # Restore whatever mode the caller had, rather than forcing train mode.
        model.train(was_training)
    if not losses:
        return float("nan")
    return torch.tensor(losses).mean().item()
| 502 | |
| 503 | # ----------------------------------------------------------------------------- |
| 504 | # helper functions for creating the training and test Datasets that emit words |