Function validation

lib/train/train_covid.py:39–70
Signature: validation(args, model, testloader, epoch, writer)

```python
import torch
import torch.nn as nn

# MetricTracker and accuracy are helpers defined elsewhere in the
# MedicalZooPytorch repository; their import paths are not shown on this page.


def validation(args, model, testloader, epoch, writer):
    model.eval()  # evaluation mode: dropout off, batch-norm uses running stats
    criterion = nn.CrossEntropyLoss(reduction='mean')

    # Note: 'correct' is registered here but never updated in the loop below.
    metric_ftns = ['loss', 'correct', 'accuracy']
    val_metrics = MetricTracker(*metric_ftns, writer=writer, mode='val')
    val_metrics.reset()
    # Rows index the true class, columns the predicted class.
    confusion_matrix = torch.zeros(args.classes, args.classes)
    with torch.no_grad():  # no gradients are needed during evaluation
        for batch_idx, input_tensors in enumerate(testloader):
            input_data, target = input_tensors
            if args.cuda:
                input_data = input_data.cuda()
                target = target.cuda()

            output = model(input_data)
            loss = criterion(output, target)

            correct, total, acc = accuracy(output, target)
            num_samples = batch_idx * args.batchSz + 1
            _, preds = torch.max(output, 1)  # predicted class = argmax over class logits
            for t, p in zip(target.cpu().view(-1), preds.cpu().view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
            # Global logging step: batches from earlier epochs plus the current batch index.
            val_metrics.update_all_metrics(batch_idx + 1, {'loss': loss.item(), 'accuracy': acc},
                                           writer_step=(epoch - 1) * len(testloader) + batch_idx)

    val_metrics.display_terminal(num_samples / len(testloader), epoch, 'VAL', summary=True)
    print('Confusion Matrix\n{}'.format(confusion_matrix.cpu().numpy()))
    return val_metrics, confusion_matrix
```
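The confusion-matrix bookkeeping in the loop above can be checked without PyTorch. The dependency-free sketch below mirrors it on plain lists: rows index the true class, columns the predicted class (the argmax of each logit row, as `torch.max(output, 1)` returns), the diagonal sum over the total gives overall accuracy, and `writer_step` reproduces the `(epoch - 1) * len(testloader) + batch_idx` step passed to `update_all_metrics`. The helper names and sample logits are hypothetical, and `accuracy_from_confusion` is an illustration, not the repository's `accuracy` function (whose scaling may differ).

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest score in one row of logits (first one on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def update_confusion(matrix: List[List[int]], targets: Sequence[int],
                     logits: Sequence[Sequence[float]]) -> None:
    """Add one batch: matrix[true_class][predicted_class] += 1 per sample."""
    for t, row in zip(targets, logits):
        matrix[t][argmax(row)] += 1


def accuracy_from_confusion(matrix: List[List[int]]) -> float:
    """Fraction of samples on the diagonal (prediction == target)."""
    total = sum(sum(r) for r in matrix)
    correct = sum(matrix[i][i] for i in range(len(matrix)))
    return correct / total if total else 0.0


def writer_step(epoch: int, batch_idx: int, num_batches: int) -> int:
    """Global logging step as computed in validation(); epochs start at 1."""
    return (epoch - 1) * num_batches + batch_idx


if __name__ == "__main__":
    num_classes = 3
    cm = [[0] * num_classes for _ in range(num_classes)]
    # Two hypothetical batches of (targets, logits).
    batches = [
        ([0, 1, 2], [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.9, 0.2, 0.4]]),
        ([2, 1], [[0.1, 0.2, 3.0], [1.1, 0.3, 0.2]]),
    ]
    for targets, logits in batches:
        update_confusion(cm, targets, logits)
    print(cm)                                    # [[1, 0, 0], [1, 1, 0], [1, 0, 1]]
    print(accuracy_from_confusion(cm))           # 0.6
    print(writer_step(epoch=2, batch_idx=0, num_batches=len(batches)))  # 2
```

Off-diagonal cells show which classes are confused with which: here one class-1 and one class-2 sample were both predicted as class 0.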

Callers (1)

main (function) · 0.90

Calls (5)

reset (method) · 0.95
update_all_metrics (method) · 0.95
display_terminal (method) · 0.95
MetricTracker (class) · 0.90
accuracy (function) · 0.90

Tested by (1)

main (function) · 0.72