MCPcopy
hub / github.com/microsoft/Cream / validate

Function validate

CDARTS/lib/core/augment_function.py:82–130  ·  view source on GitHub ↗
(valid_loader, model, epoch, cur_step, writer, logger, config)

Source from the content-addressed store, hash-verified

80 epoch+1, config.epochs, top1.avg))
81
def validate(valid_loader, model, epoch, cur_step, writer, logger, config):
    """Evaluate ``model`` on ``valid_loader`` and return top-1 / top-5 precision.

    Runs one full pass over the validation loader under ``torch.no_grad()``
    with the model in eval mode. Loss and top-1/top-5 precision are
    accumulated in ``utils.AverageMeter`` objects, weighted by batch size.
    When ``config.distributed`` is true, each per-batch metric is first
    passed through ``utils.reduce_tensor(..., config.world_size)``
    (presumably an all-reduce mean across ranks -- confirm in ``utils``).

    Args:
        valid_loader: Iterable of ``(X, y)`` batches that also supports
            ``len()`` (used for the "Step i/N" progress message).
        model: Network called as ``logits, _ = model(X)``; only the first
            element of the returned pair is used, as classification logits.
        epoch: Zero-based epoch index (printed as ``epoch + 1``).
        cur_step: Global step at which the epoch summary scalars are written.
        writer: Object with ``add_scalar(tag, value, step)`` (e.g. a
            TensorBoard ``SummaryWriter``); only used on rank 0.
        logger: Logger for progress messages; only used on rank 0.
        config: Namespace providing ``distributed``, ``world_size``,
            ``print_freq``, ``local_rank`` and ``epochs``.

    Returns:
        ``(top1.avg, top5.avg)``: batch-size-weighted average precisions as
        produced by ``utils.accuracy``. They are logged with a ``%`` format,
        so they are presumably fractions in [0, 1] -- confirm in ``utils``.

    Note:
        Always runs on the ``"cuda"`` device; a GPU is required.
    """
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    model.eval()
    device = torch.device("cuda")
    criterion = nn.CrossEntropyLoss().to(device)

    # Loop invariants, computed once rather than on every step.
    step_num = len(valid_loader)
    is_master = config.local_rank == 0

    with torch.no_grad():
        for step, (X, y) in enumerate(valid_loader):
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            N = X.size(0)

            logits, _ = model(X)
            loss = criterion(logits, y)

            prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))

            # ``detach()`` is the supported way to obtain a graph-free tensor;
            # ``.data`` silently bypasses autograd's version tracking.
            reduced_loss = loss.detach()
            if config.distributed:
                reduced_loss = utils.reduce_tensor(reduced_loss, config.world_size)
                prec1 = utils.reduce_tensor(prec1, config.world_size)
                prec5 = utils.reduce_tensor(prec5, config.world_size)

            # ``.item()`` copies each value to the host, which already waits
            # for the queued GPU work, so no explicit torch.cuda.synchronize()
            # is needed here.
            losses.update(reduced_loss.item(), N)
            top1.update(prec1.item(), N)
            top5.update(prec5.item(), N)

            if is_master and (step % config.print_freq == 0 or step == step_num - 1):
                logger.info(
                    "Valid: Epoch {:2d}/{} Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch + 1, config.epochs, step, step_num,
                        losses=losses, top1=top1, top5=top5))

    if is_master:
        writer.add_scalar('val/loss', losses.avg, cur_step)
        writer.add_scalar('val/top1', top1.avg, cur_step)
        writer.add_scalar('val/top5', top5.avg, cur_step)

        logger.info("Valid: Epoch {:2d}/{} Final Prec@1 {:.4%}".format(
            epoch + 1, config.epochs, top1.avg))

    return top1.avg, top5.avg

Callers 2

main (function) · 0.90
main (function) · 0.90

Calls 4

update (method) · 0.95
to (method) · 0.80
format (method) · 0.80
size (method) · 0.45

Tested by 1

main (function) · 0.72