validation
(val_loader, model, criterion, opt)
| 238 | |
| 239 | |
def validate(val_loader, model, criterion, opt):
    """Run one evaluation pass over ``val_loader`` and report loss and top-1 accuracy.

    The model is switched to eval mode (and is not switched back here), and
    every batch is processed under ``torch.no_grad()``. A progress line is
    printed whenever the batch index is a multiple of ``opt.print_freq``,
    followed by one summary line after the loop.

    Args:
        val_loader: Iterable yielding ``(images, labels)`` batches; it must
            also support ``len()`` for the progress line.
        model: Network invoked as ``model(images)``.
        criterion: Loss callable invoked as ``criterion(output, labels)``.
        opt: Options object; only ``opt.print_freq`` is read here.

    Returns:
        A ``(loss_avg, top1_avg)`` tuple taken from the ``avg`` attributes of
        the loss meter and the top-1 accuracy meter.
    """
    model.eval()

    time_meter = AverageMeter()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # Built once; the named fields are filled with the meters on each report.
    progress_template = (
        'Test: [{0}/{1}]\t'
        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
        'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'
    )

    with torch.no_grad():
        tick = time.time()
        for step, (images, labels) in enumerate(val_loader):
            # Both tensors are moved with .cuda(), so a CUDA device is required.
            images = images.float().cuda()
            labels = labels.cuda()
            batch_size = labels.shape[0]

            # Forward pass and loss for this batch.
            output = model(images)
            loss = criterion(output, labels)

            # Meters are weighted by the batch size.
            loss_meter.update(loss.item(), batch_size)
            # The second result is requested but not used.
            # NOTE(review): `accuracy` is defined elsewhere; presumably it
            # returns one entry per requested k -- confirm.
            top1_acc, _ = accuracy(output, labels, topk=(1, 5))
            acc_meter.update(top1_acc[0], batch_size)

            # Time elapsed since the previous batch finished.
            time_meter.update(time.time() - tick)
            tick = time.time()

            if step % opt.print_freq != 0:
                continue
            print(progress_template.format(
                step, len(val_loader), batch_time=time_meter,
                loss=loss_meter, top1=acc_meter))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=acc_meter))
    return loss_meter.avg, acc_meter.avg
| 278 | |
| 279 | |
| 280 | def main(): |
no test coverage detected