| 67 | |
| 68 | |
def test(model, data_tar, e):
    """Evaluate ``model`` on ``data_tar`` and record the results for epoch ``e``.

    Each batch is flattened with ``view(-1, 28 * 28)`` and moved to ``DEVICE``.
    The model is called as ``model(data, data)``, and only the first of its
    three outputs (the class scores fed to ``CrossEntropyLoss``) is used.

    Args:
        model: network whose forward takes two inputs and returns a 3-tuple
            whose first element is the per-class scores.
        data_tar: iterable of ``(data, target)`` batches exposing a
            ``.dataset`` attribute; ``len(data_tar.dataset)`` is the
            accuracy denominator (presumably a ``DataLoader`` — confirm).
        e: epoch index, stored alongside the metrics in ``RESULT_TEST``.

    Side effects:
        Puts ``model`` into eval mode and does not switch it back; the caller
        must call ``model.train()`` before resuming training. Writes a summary
        line via ``tqdm.write`` and to ``log_test``. Appends
        ``[e, total_loss, accuracy]`` (plain Python numbers) to ``RESULT_TEST``.
    """
    criterion = nn.CrossEntropyLoss()
    total_loss_test = 0.0
    correct = 0

    # Switch to eval mode once, not once per batch.
    model.eval()
    with torch.no_grad():
        for data, target in data_tar:
            data = data.view(-1, 28 * 28).to(DEVICE)
            target = target.to(DEVICE)
            ypred, _, _ = model(data, data)
            # CrossEntropyLoss defaults to reduction='mean', so this is a sum
            # of per-batch mean losses (same quantity the original reported).
            # .item() keeps plain floats rather than device tensors.
            total_loss_test += criterion(ypred, target).item()
            pred = ypred.argmax(dim=1)  # index of the highest score per row
            correct += pred.eq(target.view_as(pred)).sum().item()

    n_samples = len(data_tar.dataset)
    # Avoid ZeroDivisionError on an empty evaluation set.
    accuracy = correct * 100.0 / n_samples if n_samples else 0.0
    res = 'Test: total loss: {:.6f}, correct: [{}/{}], testing accuracy: {:.4f}%'.format(
        total_loss_test, correct, n_samples, accuracy
    )
    tqdm.write(res)
    RESULT_TEST.append([e, total_loss_test, accuracy])
    log_test.write(res + '\n')
| 89 | |
| 90 | |
| 91 | if __name__ == '__main__': |