MCPcopy
hub / github.com/microsoft/Cream / validate

Function validate

CDARTS/benchmark201/core/pretrain_function.py:81–129  ·  view source on GitHub ↗
(valid_loader, model, epoch, cur_step, writer, logger, config)

Source from the content-addressed store, hash-verified

79 epoch+1, config.epochs, top1.avg))
80
def validate(valid_loader, model, epoch, cur_step, writer, logger, config):
    """Evaluate ``model`` on ``valid_loader`` and return top-1/top-5 precision.

    Makes one pass over the validation batches under ``torch.no_grad()`` with
    the model in eval mode, accumulating cross-entropy loss and top-1/top-5
    precision. When ``config.distributed`` is set, each batch's loss and
    precision are reduced across ``config.world_size`` processes with
    ``utils.reduce_tensor`` before accumulation. Only the process with
    ``config.local_rank == 0`` emits log lines and writer scalars.

    Args:
        valid_loader: iterable of ``(X, y)`` batches. ``len(valid_loader)``
            must be defined because it is used for the progress line.
        model: network called as
            ``model(X, layer_idx=0, super_flag=True, pretrain_flag=True)``.
            The first element of the returned pair is used as the logits.
        epoch: zero-based epoch index (logged as ``epoch + 1``).
        cur_step: step value passed to ``writer.add_scalar``.
        writer: object exposing ``add_scalar(tag, value, step)``.
        logger: object exposing ``info(msg)``.
        config: reads ``distributed``, ``world_size``, ``print_freq``,
            ``local_rank`` and ``epochs``. A falsy ``print_freq`` disables
            periodic progress lines; the last step is still logged.

    Returns:
        ``(top1.avg, top5.avg)`` from the ``utils.AverageMeter`` accumulators.
        These are presumably fractions in [0, 1], given the ``%`` formatting
        used below (TODO: confirm against ``utils.accuracy``).

    Note:
        The model is left in eval mode. Callers that resume training must
        call ``model.train()`` themselves.
    """
    top1 = utils.AverageMeter()
    top5 = utils.AverageMeter()
    losses = utils.AverageMeter()

    model.eval()
    device = torch.device("cuda")
    criterion = nn.CrossEntropyLoss().to(device)

    # Loop invariants: the batch count and the process rank do not change
    # during the pass.
    step_num = len(valid_loader)
    is_main_process = config.local_rank == 0

    with torch.no_grad():
        for step, (X, y) in enumerate(valid_loader):
            X = X.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)
            batch_size = X.size(0)

            logits, _ = model(X, layer_idx=0, super_flag=True, pretrain_flag=True)
            loss = criterion(logits, y)
            prec1, prec5 = utils.accuracy(logits, y, topk=(1, 5))

            # detach() replaces the discouraged `.data` accessor.
            reduced_loss = loss.detach()
            if config.distributed:
                reduced_loss = utils.reduce_tensor(reduced_loss, config.world_size)
                prec1 = utils.reduce_tensor(prec1, config.world_size)
                prec5 = utils.reduce_tensor(prec5, config.world_size)

            # .item() copies each value to the host and blocks until the GPU
            # work producing it has finished, so an explicit
            # torch.cuda.synchronize() per step is unnecessary.
            losses.update(reduced_loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # Skip the modulo when print_freq is 0 to avoid ZeroDivisionError.
            periodic = bool(config.print_freq) and step % config.print_freq == 0
            if is_main_process and (periodic or step == step_num - 1):
                logger.info(
                    "Valid: Epoch {:2d}/{} Step {:03d}/{:03d} Loss {losses.avg:.3f} "
                    "Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
                        epoch+1, config.epochs, step, step_num,
                        losses=losses, top1=top1, top5=top5))

    # Summaries are written by the main process only, matching the per-step
    # progress logging above.
    if is_main_process:
        writer.add_scalar('val/loss', losses.avg, cur_step)
        writer.add_scalar('val/top1', top1.avg, cur_step)
        writer.add_scalar('val/top5', top5.avg, cur_step)

        logger.info("Valid: Epoch {:2d}/{} Final Prec@1 {:.4%}".format(
            epoch+1, config.epochs, top1.avg))

    return top1.avg, top5.avg
130
131
132def sample_train(train_loader, model, optimizer, epoch, writer, logger, config):

Callers

nothing calls this directly

Calls 4

updateMethod · 0.95
toMethod · 0.80
formatMethod · 0.80
sizeMethod · 0.45

Tested by

no test coverage detected