| 80 | epoch+1, config.epochs, top1.avg)) |
| 81 | |
def validate(valid_loader, model, epoch, cur_step, writer, logger, config):
    """Evaluate ``model`` on ``valid_loader`` and report loss / top-1 / top-5.

    The model is switched to eval mode and every batch is run under
    ``torch.no_grad()`` on the CUDA device. Per-batch loss and precision are
    accumulated into running meters, logged periodically, and written to
    ``writer`` once the whole loader has been consumed.

    Args:
        valid_loader: Iterable yielding ``(X, y)`` batches; must support
            ``len()``.
        model: Callable returning a pair whose first element is the logits
            (the second element is ignored here).
        epoch: Zero-based epoch index; shown as ``epoch + 1`` in log lines.
        cur_step: Global step passed to ``writer.add_scalar``.
        writer: Object exposing ``add_scalar(tag, value, step)``.
        logger: Logger used for progress and the final summary.
        config: Object providing ``distributed``, ``world_size``,
            ``print_freq``, ``local_rank`` and ``epochs``.

    Returns:
        Tuple ``(top1.avg, top5.avg)`` taken from the running meters.
    """
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    model.eval()
    device = torch.device("cuda")
    criterion = nn.CrossEntropyLoss().to(device)

    # Loop-invariant values: computed once instead of on every iteration.
    step_num = len(valid_loader)
    is_main_rank = config.local_rank == 0

    with torch.no_grad():
        for step, (X, y) in enumerate(valid_loader):
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            N = X.size(0)

            # The model returns a pair; only the first element is used here.
            logits, _ = model(X)
            loss = criterion(logits, y)

            prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))

            if config.distributed:
                # NOTE(review): reduce_tensor presumably averages the value
                # over `world_size` ranks -- confirm against utils.
                reduced_loss = utils.reduce_tensor(loss.detach(), config.world_size)
                prec1 = utils.reduce_tensor(prec1, config.world_size)
                prec5 = utils.reduce_tensor(prec5, config.world_size)
            else:
                reduced_loss = loss.detach()

            # Each update is given the batch size N alongside the value
            # (presumably used as the weight of the running average).
            losses.update(reduced_loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)

            torch.cuda.synchronize()

            # Log every `print_freq` steps and on the last step, rank 0 only.
            if (step % config.print_freq == 0 or step == step_num - 1) and is_main_rank:
                logger.info(
                    "Valid: Epoch {:2d}/{} Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1, config.epochs, step, step_num,
                        losses=losses, top1=top1, top5=top5))

    # Only rank 0 writes summaries and the final log line, matching the
    # rank-0-only per-step logging above.
    if is_main_rank:
        writer.add_scalar('val/loss', losses.avg, cur_step)
        writer.add_scalar('val/top1', top1.avg, cur_step)
        writer.add_scalar('val/top5', top5.avg, cur_step)

        logger.info("Valid: Epoch {:2d}/{} Final Prec@1 {:.4%}".format(
            epoch + 1, config.epochs, top1.avg))

    return top1.avg, top5.avg