(val_loader, model, criterion, args)
| 371 | |
| 372 | |
def validate(val_loader, model, criterion, args):
    """Evaluate ``model`` over ``val_loader`` and return the average top-1 accuracy.

    The model is put into eval mode and every batch is processed inside
    ``torch.no_grad()``. Each ``(images, target)`` pair is moved with
    ``.cuda(args.gpu, non_blocking=True)`` before the forward pass. Loss,
    top-1 / top-5 accuracy and per-batch wall-clock time are accumulated in
    ``AverageMeter`` objects. A progress line is shown every
    ``args.print_freq`` batches, and a one-line top-1 / top-5 summary is
    printed when the loop finishes.

    Args:
        val_loader: Iterable of ``(images, target)`` batches. ``len()`` of it
            is used as the progress total.
        model: Callable module to evaluate; ``model.eval()`` is called on it.
        criterion: Loss callable, invoked as ``criterion(output, target)``.
        args: Namespace that must provide ``gpu`` (device passed to ``.cuda``)
            and ``print_freq`` (display interval, in batches).

    Returns:
        ``top1.avg`` as accumulated by ``AverageMeter``. Each update is
        weighted by the batch size ``images.size(0)``. The value is whatever
        type ``accuracy`` yields (presumably a tensor); confirm with the
        caller before treating it as a float.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    meters = [batch_time, losses, top1, top5]
    progress = ProgressMeter(len(val_loader), meters, prefix='Test: ')

    # switch to evaluate mode before running any batch
    model.eval()

    with torch.no_grad():
        tick = time.time()
        for step, (inputs, labels) in enumerate(val_loader):
            inputs = inputs.cuda(args.gpu, non_blocking=True)
            labels = labels.cuda(args.gpu, non_blocking=True)

            # forward pass and loss for this batch
            logits = model(inputs)
            batch_loss = criterion(logits, labels)

            # accuracy is computed before loss.item(), matching the original order
            acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
            batch_size = inputs.size(0)
            losses.update(batch_loss.item(), batch_size)
            top1.update(acc1[0], batch_size)
            top5.update(acc5[0], batch_size)

            # elapsed time covers data loading plus compute for this batch
            batch_time.update(time.time() - tick)
            tick = time.time()

            if not step % args.print_freq:
                progress.display(step)

        # TODO: this should also be done with the ProgressMeter
        summary = ' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(
            top1=top1, top5=top5)
        print(summary)

    return top1.avg
| 414 | |
| 415 | |
| 416 | def save_checkpoint(state, is_best, filename, best_filename): |
no test coverage detected