MCPcopy
hub / github.com/DingXiaoH/RepVGG / validate

Function validate

quantization/quant_qat_train.py:373–413  ·  view source on GitHub ↗
(val_loader, model, criterion, args)

Source from the content-addressed store, hash-verified

371
372
def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` over ``val_loader`` and report top-1 / top-5 accuracy.

    The model is switched to evaluation mode (and is not switched back here),
    and every batch is processed under ``torch.no_grad()``. For each batch the
    loss from ``criterion`` and the top-1 / top-5 accuracy from ``accuracy``
    are fed into ``AverageMeter`` instances, with the batch size passed as the
    second argument to ``update`` (presumably a sample count — confirm against
    ``AverageMeter``). Progress is printed every ``args.print_freq`` batches,
    and a one-line summary is printed once the loader is exhausted.

    Args:
        val_loader: iterable yielding ``(images, target)`` pairs; ``len()`` is
            also taken on it to size the progress display.
        model: network called as ``model(images)``; ``.eval()`` is invoked on it.
        criterion: loss function called as ``criterion(output, target)``.
        args: namespace read for ``gpu`` (device argument passed to
            ``.cuda()``) and ``print_freq`` (progress-print interval, batches).

    Returns:
        Whatever the top-1 meter's ``avg`` attribute holds after the loop
        (possibly a tensor, since ``acc1[0]`` is what gets fed in — confirm).
    """
    timer = AverageMeter('Time', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc1_meter = AverageMeter('Acc@1', ':6.2f')
    acc5_meter = AverageMeter('Acc@5', ':6.2f')
    meters = [timer, loss_meter, acc1_meter, acc5_meter]
    progress = ProgressMeter(len(val_loader), meters, prefix='Test: ')

    # Put the model into inference mode before running any batches.
    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, (inputs, labels) in enumerate(val_loader):
            # Move the batch onto the configured GPU (asynchronous copy).
            inputs = inputs.cuda(args.gpu, non_blocking=True)
            labels = labels.cuda(args.gpu, non_blocking=True)

            # Forward pass, loss and accuracy for this batch.
            logits = model(inputs)
            batch_loss = criterion(logits, labels)
            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))

            batch_size = inputs.size(0)
            loss_meter.update(batch_loss.item(), batch_size)
            acc1_meter.update(acc1[0], batch_size)
            acc5_meter.update(acc5[0], batch_size)

            # Elapsed wall-clock time since the previous tick; this includes
            # the time spent fetching the batch from the loader.
            timer.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress.display(step)

        # TODO: this summary line should also be produced via ProgressMeter.
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=acc1_meter, top5=acc5_meter))

    return acc1_meter.avg
414
415
416def save_checkpoint(state, is_best, filename, best_filename):

Callers 1

main_worker (function) · 0.70

Calls 5

update (method) · 0.95
display (method) · 0.95
AverageMeter (class) · 0.85
ProgressMeter (class) · 0.85
accuracy (function) · 0.85

Tested by

no test coverage detected