(self, epoch, val_loader, model, classifier, criterion)
| 191 | return top1.avg, top5.avg, losses.avg |
| 192 | |
def validate(self, epoch, val_loader, model, classifier, criterion):
    """Evaluate ``model`` followed by ``classifier`` on ``val_loader``.

    Both modules are put into eval mode (this method does not switch them
    back to train mode). Under ``torch.no_grad()``, every batch is cast to
    float and moved to ``self.args.gpu``. It is then passed through
    ``model(x=..., mode=2)`` and the result through ``classifier``. The
    logits are scored with ``criterion`` and with top-1/top-5 accuracy.

    Args:
        epoch: Epoch number; used only in the final timing message.
        val_loader: Iterable of ``(input, target)`` batches. It must support
            ``len()``, which is used in the progress message.
        model: Backbone, called as ``model(x=..., mode=2)``.
            NOTE(review): the meaning of ``mode=2`` is defined by the model
            class and is not visible here (presumably "return features");
            confirm.
        classifier: Module applied to the backbone output to produce logits.
        criterion: Loss, called as ``criterion(output, target)``.

    Returns:
        ``(top1.avg, top5.avg, losses.avg)`` from the ``AverageMeter``
        instances. Each meter is updated with the batch size as its weight.
        NOTE(review): ``acc1[0]`` / ``acc5[0]`` are passed to the meters
        as-is (not ``.item()``), so the accuracy averages may be tensors,
        depending on ``accuracy`` / ``AverageMeter``. Confirm before relying
        on them being Python floats.
    """
    start = time.time()
    args = self.args

    model.eval()
    classifier.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    with torch.no_grad():
        end = time.time()
        # `images` (not `input`) so the builtin `input` is not shadowed.
        for idx, (images, target) in enumerate(val_loader):
            images = images.float().cuda(args.gpu, non_blocking=True)
            target = target.cuda(args.gpu, non_blocking=True)
            batch_size = images.size(0)

            # compute output
            feat = model(x=images, mode=2)
            output = classifier(feat)
            loss = criterion(output, target)

            # measure accuracy and record loss, weighted by batch size
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(acc1[0], batch_size)
            top5.update(acc5[0], batch_size)

            # Elapsed time per batch. `end` is reset only after a batch has
            # been processed, so this also includes loading the next batch.
            batch_time.update(time.time() - end)
            end = time.time()

            # Per-batch progress is printed only by local rank 0.
            if args.local_rank == 0 and idx % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Acc@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                          idx, len(val_loader), batch_time=batch_time,
                          loss=losses, top1=top1, top5=top5))

        # NOTE(review): unlike the per-batch line above, this summary (and
        # the timing line below) is not guarded by `args.local_rank == 0`,
        # so every process prints it. Confirm this is intended.
        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    finish = time.time()
    print('eval epoch {}, total time {:.2f}'.format(epoch, finish - start))

    return top1.avg, top5.avg, losses.avg
no test coverage detected