MCPcopy
hub / github.com/HobbitLong/SupContrast / validate

Function validate

main_ce.py:240–277  ·  view source on GitHub ↗

validation

(val_loader, model, criterion, opt)

Source from the content-addressed store, hash-verified

238
239
def validate(val_loader, model, criterion, opt):
    """Evaluate ``model`` on ``val_loader`` and report loss and top-1 accuracy.

    The model is switched to eval mode and every batch is processed with
    gradient tracking disabled. Inputs and labels are moved to the GPU with
    ``.cuda()``, so a CUDA device is required.

    Args:
        val_loader: iterable yielding ``(images, labels)`` batches; it must
            support ``len()`` for the progress line.
        model: callable mapping the image batch to the output scored by
            ``criterion`` and ``accuracy``.
        criterion: callable taking ``(output, labels)`` and returning a loss
            object exposing ``.item()``.
        opt: options object; only ``opt.print_freq`` is read here (a progress
            line is printed whenever the batch index is a multiple of it).

    Returns:
        A ``(loss_avg, top1_avg)`` tuple taken from the ``avg`` attribute of
        the loss and top-1 meters.
    """
    model.eval()

    # One meter each for per-batch wall time, loss and top-1 accuracy.
    timer = AverageMeter()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    progress_fmt = ('Test: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Acc@1 {top1.val:.3f} ({top1.avg:.3f})')

    with torch.no_grad():
        tick = time.time()
        for step, batch in enumerate(val_loader):
            images, labels = batch
            inputs = images.float().cuda()
            targets = labels.cuda()
            batch_size = targets.shape[0]

            # Forward pass and loss for this batch.
            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            # Record the metrics together with the batch size. Top-5 is
            # requested from accuracy() but only top-1 is recorded.
            loss_meter.update(batch_loss.item(), batch_size)
            top1_acc, _ = accuracy(logits, targets, topk=(1, 5))
            acc_meter.update(top1_acc[0], batch_size)

            # Time elapsed since the end of the previous batch.
            timer.update(time.time() - tick)
            tick = time.time()

            if step % opt.print_freq != 0:
                continue
            print(progress_fmt.format(
                step, len(val_loader), batch_time=timer,
                loss=loss_meter, top1=acc_meter))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=acc_meter))
    return loss_meter.avg, acc_meter.avg
278
279
280def main():

Callers 1

mainFunction · 0.70

Calls 3

updateMethod · 0.95
AverageMeterClass · 0.90
accuracyFunction · 0.90

Tested by

no test coverage detected