| 402 | |
@torch.no_grad()
def validate(args, config, data_loader, model, num_classes=1000):
    """Evaluate ``model`` on ``data_loader``; return top-1, top-5 and loss averages.

    The model is put in eval mode and run without gradients (``@torch.no_grad``),
    with the forward pass under ``torch.cuda.amp.autocast`` when
    ``config.AMP_ENABLE`` is set. Progress is logged every ``config.PRINT_FREQ``
    batches.

    Args:
        args: Parsed command-line arguments. When ``args.only_cpu`` is true the
            batches stay on the CPU; otherwise they are moved to CUDA.
        config: Config object; ``AMP_ENABLE`` and ``PRINT_FREQ`` are read.
        data_loader: Iterable of ``(images, target)`` batches that supports
            ``len()`` (used in the progress log).
        model: Module to evaluate; called as ``model(images)``.
        num_classes: Class count of the evaluation dataset. When it is 1000 and
            the model emits 21841 logits, the logits are passed through
            ``remap_layer_22kto1k`` before loss/accuracy are computed.

    Returns:
        Tuple ``(acc1_avg, acc5_avg, loss_avg)``. All three meters are passed
        through ``AverageMeter.sync()`` before their averages are read.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        if not args.only_cpu:
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output = model(images)

        # A 21841-logit head evaluated against a 1000-class dataset: remap the
        # logits into the 1000-class space so they line up with ``target``.
        if num_classes == 1000:
            output_num_classes = output.size(-1)
            if output_num_classes == 21841:
                output = remap_layer_22kto1k(output)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        batch_size = target.size(0)
        loss_meter.update(loss.item(), batch_size)
        acc1_meter.update(acc1.item(), batch_size)
        acc5_meter.update(acc5.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            # The peak-memory statistic is CUDA-only; on CPU-only runs there is
            # no CUDA allocator to query, so report 0 MB instead.
            if args.only_cpu:
                memory_used = 0.0
            else:
                memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')

    # Sync the loss meter together with the accuracy meters. Previously only
    # the accuracies were synced, so the returned loss described a different
    # set of samples than the returned accuracies. It is synced at the same
    # point as the existing calls, so it has the same calling requirements.
    acc1_meter.sync()
    acc5_meter.sync()
    loss_meter.sync()
    logger.info(
        f' The number of validation samples is {int(acc1_meter.count)}')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
| 455 | |
| 456 | |
| 457 | @torch.no_grad() |