MCPcopy
hub / github.com/HobbitLong/SupContrast / validate

Function validate

main_linear.py:188–226  ·  view source on GitHub ↗

validation

(val_loader, model, classifier, criterion, opt)

Source from the content-addressed store, hash-verified

186
187
def validate(val_loader, model, classifier, criterion, opt):
    """Evaluate the encoder + classifier pair on a validation loader.

    Both modules are switched to eval mode and every batch is processed
    under ``torch.no_grad()``. A progress line is printed every
    ``opt.print_freq`` batches, followed by a final top-1 summary line.

    Args:
        val_loader: Iterable yielding ``(images, labels)`` batches; it must
            support ``len()`` because the progress line reports the total.
        model: Object exposing an ``encoder`` callable applied to the images.
        classifier: Callable mapping the encoder output to predictions.
        criterion: Loss callable invoked as ``criterion(output, labels)``.
        opt: Options object; only ``opt.print_freq`` is read here.

    Returns:
        Tuple ``(loss_avg, top1_avg)`` read from the ``avg`` attribute of
        the running loss and top-1 meters.
    """
    model.eval()
    classifier.eval()

    timer = AverageMeter()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # Keyword names inside the template are fixed by the format string;
    # the meters are bound to them at print time.
    progress_template = (
        'Test: [{0}/{1}]\t'
        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
        'Acc@1 {top1.val:.3f} ({top1.avg:.3f})'
    )

    with torch.no_grad():
        tick = time.time()
        for step, batch in enumerate(val_loader):
            images, labels = batch
            images = images.float().cuda()
            labels = labels.cuda()
            n_samples = labels.shape[0]

            # Forward pass: encoder output feeds the classifier.
            features = model.encoder(images)
            logits = classifier(features)
            batch_loss = criterion(logits, labels)

            # Fold this batch into the running meters. The top-5 result is
            # computed by the call but intentionally discarded here.
            loss_meter.update(batch_loss.item(), n_samples)
            top1_acc, _ = accuracy(logits, labels, topk=(1, 5))
            acc_meter.update(top1_acc[0], n_samples)

            # Record wall-clock time for this batch, then restart the clock.
            timer.update(time.time() - tick)
            tick = time.time()

            if step % opt.print_freq != 0:
                continue
            print(progress_template.format(
                step, len(val_loader),
                batch_time=timer, loss=loss_meter, top1=acc_meter))

    print(' * Acc@1 {top1.avg:.3f}'.format(top1=acc_meter))
    return loss_meter.avg, acc_meter.avg
227
228
229def main():

Callers 1

main (Function) · 0.70

Calls 3

update (Method) · 0.95
AverageMeter (Class) · 0.90
accuracy (Function) · 0.90

Tested by

no test coverage detected