MCPcopy Index your code
hub / github.com/Turing-Project/WriteGPT / validate

Function validate

RecognizaitonNetwork/crnn/func.py:79–133  ·  view source on GitHub ↗
(config, val_loader, dataset, converter, model, criterion, device, epoch, writer_dict, output_dict)

Source from the content-addressed store, hash-verified

77
78
def validate(config, val_loader, dataset, converter, model, criterion, device, epoch, writer_dict, output_dict):
    """Evaluate the recognition model on (a prefix of) the validation set.

    Runs the model under ``torch.no_grad()`` on at most
    ``config.TEST.NUM_TEST_BATCH`` batches from ``val_loader``. It accumulates
    the average ``criterion`` loss and the exact-match accuracy of the decoded
    strings. It prints a few raw/decoded predictions from the last evaluated
    batch and, if ``writer_dict`` is given, logs the accuracy.

    Args:
        config: Experiment config; reads ``PRINT_FREQ``,
            ``TEST.NUM_TEST_BATCH`` and ``TEST.NUM_TEST_DISP``.
        val_loader: Iterable of ``(images, indices)`` batches; ``len()`` is
            used for progress messages.
        dataset: Passed with the batch indices to ``utils.get_batch_label``
            to obtain the ground-truth label strings.
        converter: Provides ``encode(labels) -> (text, length)`` and
            ``decode(preds, sizes, raw=...)``.
        model: Network whose output is reduced with ``max(2)``; presumably
            shaped (T, batch, n_classes) -- TODO confirm.
        criterion: Loss called as ``criterion(preds, text, preds_size, length)``
            (looks like a CTC-style signature -- confirm).
        device: Device the input images are moved to.
        epoch: Used only in progress messages.
        writer_dict: Optional dict with ``'writer'`` (must expose
            ``add_scalar``) and ``'valid_global_steps'``; the step counter
            is incremented in place.
        output_dict: Unused by this function.

    Returns:
        float: Fraction of evaluated samples whose decoded prediction equals
        the ground-truth label (0.0 if no sample was evaluated).
    """
    losses = AverageMeter()
    model.eval()

    max_batches = config.TEST.NUM_TEST_BATCH
    n_correct = 0
    # Count the samples actually scored, so the accuracy denominator always
    # matches the numerator (short last batch, dataset smaller than the
    # nominal NUM_TEST_BATCH * batch size, ...).
    n_total = 0
    # Tensors/strings of the last evaluated batch, for the sample printout.
    last_batch = None

    with torch.no_grad():
        for i, (inp, idx) in enumerate(val_loader):

            labels = utils.get_batch_label(dataset, idx)
            inp = inp.to(device)

            # inference
            preds = model(inp).cpu()

            # compute loss
            batch_size = inp.size(0)
            text, length = converter.encode(labels)
            # Every sample in the batch is given the full output length.
            preds_size = torch.IntTensor([preds.size(0)] * batch_size)
            loss = criterion(preds, text, preds_size, length)

            losses.update(loss.item(), batch_size)

            # Greedy decoding: best class per step, then transpose and flatten
            # (batch-major, assuming (T, batch) here) so the converter can split
            # the sequence back per sample using preds_size.
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds.data, preds_size.data, raw=False)
            n_correct += sum(1 for pred, target in zip(sim_preds, labels) if pred == target)
            n_total += batch_size

            if (i + 1) % config.PRINT_FREQ == 0:
                print('Epoch: [{0}][{1}/{2}]'.format(epoch, i, len(val_loader)))

            last_batch = (preds, preds_size, sim_preds, labels)

            # Stop after exactly NUM_TEST_BATCH batches (the old
            # `i == NUM_TEST_BATCH` check let one extra batch through).
            if i + 1 >= max_batches:
                break

    # Guard: an empty loader leaves nothing to display.
    if last_batch is not None:
        preds, preds_size, sim_preds, labels = last_batch
        raw_preds = converter.decode(preds.data, preds_size.data, raw=True)[:config.TEST.NUM_TEST_DISP]
        for raw_pred, pred, gt in zip(raw_preds, sim_preds, labels):
            print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    print("[#correct:{} / #total:{}]".format(n_correct, n_total))
    accuracy = n_correct / float(n_total) if n_total else 0.0
    print('Test loss: {:.4f}, accuracy: {:.4f}'.format(losses.avg, accuracy))

    if writer_dict:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_acc', accuracy, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1

    return accuracy

Callers

nothing calls this directly

Calls 5

update (method) · 0.95
AverageMeter (class) · 0.85
model (function) · 0.85
encode (method) · 0.45
decode (method) · 0.45

Tested by

no test coverage detected