(epoch)
| 111 | |
| 112 | |
def test(epoch):
    """Evaluate the model on the test set and save a reconstruction preview.

    Runs the model over every batch of ``test_loader`` with gradient tracking
    disabled and accumulates ``loss_function`` over all batches. It then
    prints that total divided by the number of test samples. For the first
    batch only, it writes an image grid to
    ``results/reconstruction_<epoch>.png``. The grid pairs up to 8 input
    images with their reconstructions.

    Relies on module-level ``model``, ``device``, ``test_loader``,
    ``loss_function`` and ``save_image`` defined elsewhere in the file.

    Args:
        epoch: Epoch number; used only to name the saved image file.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            # NOTE(review): the division by dataset size below assumes
            # loss_function returns the *summed* (not mean) batch loss - confirm.
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Reshape with the real batch size instead of args.batch_size.
                # The first test batch can be smaller than the configured
                # batch size, and the old view() raised a RuntimeError then.
                # NOTE(review): 1x28x28 assumes single-channel 28x28 inputs,
                # as the original code did.
                recon_images = recon_batch.view(data.size(0), 1, 28, 28)
                comparison = torch.cat([data[:n], recon_images[:n]])
                save_image(comparison.cpu(),
                           f'results/reconstruction_{epoch}.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
| 130 | |
| 131 | if __name__ == "__main__": |
| 132 | for epoch in range(1, args.epochs + 1): |
no test coverage detected