validation
(val_loader, model, classifier, criterion, opt)
| 186 | |
| 187 | |
def validate(val_loader, model, classifier, criterion, opt):
    """Evaluate ``classifier`` on features from ``model.encoder`` over a loader.

    Runs a single pass over ``val_loader`` with gradients disabled and both
    ``model`` and ``classifier`` in eval mode. Each batch is moved to the
    current CUDA device, encoded with ``model.encoder``, classified, and
    scored with ``criterion``. Progress is printed every ``opt.print_freq``
    batches, and the final top-1 average is printed at the end.

    Args:
        val_loader: iterable of ``(images, labels)`` batches. ``len()`` is
            also taken for the progress line.
        model: network exposing an ``encoder`` sub-module.
        classifier: head applied to the encoder output to produce logits.
        criterion: loss callable invoked as ``criterion(output, labels)``.
        opt: options object; only ``opt.print_freq`` is read here.

    Returns:
        ``(losses.avg, top1.avg)``: the running averages held by the loss and
        top-1 ``AverageMeter`` instances. Each is updated with ``n=bsz``
        per batch.
    """
    model.eval()
    classifier.eval()

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    with torch.no_grad():
        end = time.time()
        for idx, (images, labels) in enumerate(val_loader):
            images = images.float().cuda()
            labels = labels.cuda()
            bsz = labels.shape[0]

            # forward
            output = classifier(model.encoder(images))
            loss = criterion(output, labels)

            # update metric
            losses.update(loss.item(), bsz)
            # Only top-1 is reported, so request k=1 alone. The previous
            # topk=(1, 5) computed an unused top-5 value and presumably breaks
            # (top-k with k > number of classes) for classifiers with fewer
            # than 5 outputs. NOTE(review): assumes `accuracy` returns one
            # entry per requested k, as the original two-way unpack implies.
            acc1 = accuracy(output, labels, topk=(1,))[0]
            top1.update(acc1[0], bsz)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if idx % opt.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          idx, len(val_loader), batch_time=batch_time,
                          loss=losses, top1=top1))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=top1))
    return losses.avg, top1.avg
| 227 | |
| 228 | |
| 229 | def main(): |
no test coverage detected