(epoch, model, loader, loss_fn, cfg, log_suffix='', logger=None, writer=None, local_rank=0)
| 11 | |
def validate(epoch, model, loader, loss_fn, cfg, log_suffix='', logger=None, writer=None, local_rank=0):
    """Evaluate ``model`` over ``loader``, tracking loss and top-1/top-5 precision.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    When ``cfg.TTA > 1`` the outputs of each group of ``cfg.TTA`` consecutive
    rows are averaged, and the targets are strided to keep one label per
    group. When ``cfg.NUM_GPU > 1`` the per-batch loss and precisions go
    through ``reduce_tensor`` before being accumulated. That is presumably a
    cross-process average; confirm against its definition.

    Args:
        epoch: Current epoch index. Not read by the per-batch loop.
        model: Network to evaluate. If it returns a tuple/list, only the
            first element is used as the logits.
        loader: Sized iterable of ``(input, target)`` batches. Inputs are fed
            to the model unchanged, so they presumably already live on the
            model's device (TODO confirm with the loader).
        loss_fn: Criterion, called as ``loss_fn(output, target)``.
        cfg: Config object. Reads ``TTA``, ``NUM_GPU`` and ``LOG_INTERVAL``.
        log_suffix: Appended to ``'Test'`` to form the progress-log prefix.
        logger: Logger for progress lines. Logging is skipped when ``None``.
        writer: Not read by the per-batch loop.
        local_rank: Only rank 0 emits progress log lines.
    """
    batch_time_m = AverageMeter()
    losses_m = AverageMeter()
    prec1_m = AverageMeter()
    prec5_m = AverageMeter()

    model.eval()

    # Loop invariants, hoisted out of the per-batch loop.
    reduce_factor = cfg.TTA
    log_name = 'Test' + log_suffix
    # Only synchronize when CUDA exists, so evaluation also works on CPU-only hosts.
    sync_cuda = torch.cuda.is_available()

    end = time.time()
    last_idx = len(loader) - 1
    with torch.no_grad():
        for batch_idx, (images, target) in enumerate(loader):
            last_batch = batch_idx == last_idx

            output = model(images)
            if isinstance(output, (tuple, list)):
                output = output[0]

            # Test-time augmentation reduction. Rows [i*r, (i+1)*r) are treated
            # as the r augmented copies of sample i: their outputs are averaged
            # and one target per group is kept.
            if reduce_factor > 1:
                output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                target = target[0:target.size(0):reduce_factor]

            loss = loss_fn(output, target)
            prec1, prec5 = accuracy(output, target, topk=(1, 5))

            if cfg.NUM_GPU > 1:
                reduced_loss = reduce_tensor(loss.detach(), cfg.NUM_GPU)
                prec1 = reduce_tensor(prec1, cfg.NUM_GPU)
                prec5 = reduce_tensor(prec5, cfg.NUM_GPU)
            else:
                reduced_loss = loss.detach()

            if sync_cuda:
                # Block until queued CUDA work finishes before timing the batch.
                torch.cuda.synchronize()

            # Weight every meter by the number of de-duplicated samples (output
            # rows), so loss and precision use identical counts under TTA.
            num_samples = output.size(0)
            losses_m.update(reduced_loss.item(), num_samples)
            prec1_m.update(prec1.item(), num_samples)
            prec5_m.update(prec5.item(), num_samples)

            batch_time_m.update(time.time() - end)
            end = time.time()
            should_log = last_batch or batch_idx % cfg.LOG_INTERVAL == 0
            if logger is not None and local_rank == 0 and should_log:
                logger.info(
                    '{0}: [{1:>4d}/{2}] '
                    'Time: {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Prec@1: {top1.val:>7.4f} ({top1.avg:>7.4f}) '
                    'Prec@5: {top5.val:>7.4f} ({top5.avg:>7.4f})'.format(
                        log_name, batch_idx, last_idx,
                        batch_time=batch_time_m, loss=losses_m,
                        top1=prec1_m, top5=prec5_m))
| 70 |
no test coverage detected