(config, val_loader, dataset, converter, model, criterion, device, epoch, writer_dict, output_dict)
| 77 | |
| 78 | |
def validate(config, val_loader, dataset, converter, model, criterion, device, epoch, writer_dict, output_dict):
    """Evaluate a recognition model on (a prefix of) the validation set.

    Runs at most ``config.TEST.NUM_TEST_BATCH`` batches from ``val_loader``
    under ``torch.no_grad()``, averaging the criterion loss and counting the
    predictions whose decoded string exactly equals the ground-truth label.
    A few raw/decoded predictions from the last evaluated batch are printed
    for inspection.

    Args:
        config: Experiment config; reads ``PRINT_FREQ``,
            ``TEST.NUM_TEST_BATCH`` and ``TEST.NUM_TEST_DISP``.
        val_loader: Iterable of ``(inp, idx)`` pairs; ``idx`` is passed to
            ``utils.get_batch_label`` to look up the batch labels.
        dataset: Dataset the labels are read from.
        converter: Label codec providing ``encode(labels) -> (text, length)``
            and ``decode(preds, lengths, raw=...)``.
        model: Network whose CPU output is treated here as
            ``(time, batch, classes)`` (dim 0 is used as every sample's
            sequence length, dim 2 is argmax-ed).
        criterion: Loss called as ``criterion(preds, text, preds_size, length)``.
        device: Device the inputs are moved to.
        epoch: Epoch number, used only in progress messages.
        writer_dict: Optional dict with ``'writer'`` (has ``add_scalar``) and
            ``'valid_global_steps'``; the step counter is incremented in place.
        output_dict: Unused; kept for interface compatibility.

    Returns:
        float: Exact-match accuracy over the samples actually evaluated
        (0.0 if no batch was evaluated).
    """
    losses = AverageMeter()
    model.eval()

    max_batches = config.TEST.NUM_TEST_BATCH
    n_correct = 0
    n_total = 0
    # (preds, preds_size, sim_preds, labels) of the most recent batch, for the
    # sample display after the loop; stays None if the loader yields nothing.
    last_batch = None

    with torch.no_grad():
        for i, (inp, idx) in enumerate(val_loader):
            labels = utils.get_batch_label(dataset, idx)
            inp = inp.to(device)
            batch_size = inp.size(0)

            # inference
            preds = model(inp).cpu()

            # compute loss: every sample is scored over the full output length
            text, length = converter.encode(labels)
            preds_size = torch.IntTensor([preds.size(0)] * batch_size)
            loss = criterion(preds, text, preds_size, length)
            losses.update(loss.item(), batch_size)

            # Greedy decoding: best class per step, flattened batch-major so
            # the converter can slice each sample by preds_size.
            _, preds = preds.max(2)
            preds = preds.transpose(1, 0).contiguous().view(-1)
            sim_preds = converter.decode(preds, preds_size, raw=False)
            if isinstance(sim_preds, str):
                # NOTE(review): decode may return a bare str for a batch of
                # one; wrap it so zip() compares whole strings, not characters.
                sim_preds = [sim_preds]

            n_correct += sum(pred == target for pred, target in zip(sim_preds, labels))
            # Count what was actually evaluated instead of assuming full batches.
            n_total += batch_size
            last_batch = (preds, preds_size, sim_preds, labels)

            if (i + 1) % config.PRINT_FREQ == 0:
                print('Epoch: [{0}][{1}/{2}]'.format(epoch, i, len(val_loader)))

            # `i + 1` so that exactly NUM_TEST_BATCH batches are evaluated.
            if i + 1 >= max_batches:
                break

    if last_batch is not None:
        preds, preds_size, sim_preds, labels = last_batch
        raw_preds = converter.decode(preds, preds_size, raw=True)
        if isinstance(raw_preds, str):
            raw_preds = [raw_preds]
        for raw_pred, pred, gt in zip(raw_preds[:config.TEST.NUM_TEST_DISP], sim_preds, labels):
            print('%-20s => %-20s, gt: %-20s' % (raw_pred, pred, gt))

    print("[#correct:{} / #total:{}]".format(n_correct, n_total))
    # Guard against an empty loader instead of dividing by zero.
    accuracy = n_correct / float(n_total) if n_total else 0.0
    print('Test loss: {:.4f}, accuracy: {:.4f}'.format(losses.avg, accuracy))

    if writer_dict:
        writer = writer_dict['writer']
        global_steps = writer_dict['valid_global_steps']
        writer.add_scalar('valid_acc', accuracy, global_steps)
        writer_dict['valid_global_steps'] = global_steps + 1

    return accuracy
nothing calls this directly
no test coverage detected