MCPcopy
hub / github.com/microsoft/Cream / validate

Function validate

TinyViT/main.py:404–454  ·  view source on GitHub ↗
(args, config, data_loader, model, num_classes=1000)

Source from the content-addressed store, hash-verified

402
@torch.no_grad()
def validate(args, config, data_loader, model, num_classes=1000):
    """Evaluate ``model`` on ``data_loader``; report loss and top-1/top-5 accuracy.

    Args:
        args: Run options; only ``args.only_cpu`` is read here. When it is
            falsy, every batch is moved to CUDA before the forward pass.
        config: Run configuration; ``config.AMP_ENABLE`` toggles autocast for
            the forward pass and ``config.PRINT_FREQ`` is the logging period
            in batches.
        data_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to evaluate. It is put in eval mode and left there.
        num_classes: Number of classes in the evaluation labels. When it is
            1000 and the model emits 21841 logits, the logits are passed
            through ``remap_layer_22kto1k`` before being scored.

    Returns:
        Tuple ``(acc1, acc5, loss)`` holding the averages of the respective
        meters, read after each meter has been ``sync()``-ed.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        if not args.only_cpu:
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # Forward pass, optionally under mixed precision.
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output = model(images)

        # A 21841-way head evaluated against 1000-class labels must be
        # remapped first so logits and targets share one label space.
        if num_classes == 1000 and output.size(-1) == 21841:
            output = remap_layer_22kto1k(output)

        # Measure accuracy and record loss, weighting by batch size.
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = target.size(0)
        loss_meter.update(loss.item(), batch_size)
        acc1_meter.update(acc1.item(), batch_size)
        acc5_meter.update(acc5.item(), batch_size)

        # Measure elapsed time for this batch.
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')

    # The loss meter is synced together with the accuracy meters so that all
    # three returned averages are computed on the same basis.
    # NOTE(review): sync() presumably merges per-process statistics across
    # distributed workers -- confirm against AverageMeter.
    loss_meter.sync()
    acc1_meter.sync()
    acc5_meter.sync()
    logger.info(
        f' The number of validation samples is {int(acc1_meter.count)}')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
455
456
457@torch.no_grad()

Callers 1

mainFunction · 0.70

Calls 5

updateMethod · 0.95
syncMethod · 0.95
AverageMeterClass · 0.90
accuracyFunction · 0.90
sizeMethod · 0.45

Tested by

no test coverage detected