MCPcopy
hub / github.com/microsoft/Cream / validate

Function validate

TinyViT/save_logits.py:238–286  ·  view source on GitHub ↗
(config, data_loader, model, num_classes=1000)

Source from the content-addressed store, hash-verified

236
@torch.no_grad()
def validate(config, data_loader, model, num_classes=1000):
    """Evaluate ``model`` on ``data_loader``; report loss and top-1/top-5 accuracy.

    Runs without autograd and with the model switched to eval mode. Each
    batch is moved to CUDA and forwarded, optionally under CUDA AMP autocast.
    It is then scored with cross-entropy and top-1/top-5 accuracy. Each
    metric is accumulated into a running average weighted by batch size.

    Args:
        config: Configuration object. This function reads
            ``config.AMP_ENABLE`` (passed to ``torch.cuda.amp.autocast``) and
            ``config.PRINT_FREQ`` (log progress every N batches).
        data_loader: Iterable of ``(images, target)`` batches. It must also
            support ``len()``, which is used in the progress log line.
        model: Module to evaluate. It is put into eval mode and left there.
        num_classes: Label-space size of the targets. When it is 1000 and the
            model emits 21841 logits, the logits are first passed through
            ``remap_layer_22kto1k``. That is presumably a 22k -> 1k
            ImageNet logit remap; confirm against its definition.

    Returns:
        Tuple ``(acc1_avg, acc5_avg, loss_avg)`` taken from the synced meters.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        batch_size = target.size(0)

        # compute output
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output = model(images)

        # A model with a 21841-way head is evaluated against 1000-class
        # targets by remapping its logits first.
        if num_classes == 1000 and output.size(-1) == 21841:
            output = remap_layer_22kto1k(output)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # Weight by batch size so a smaller final batch does not skew the mean.
        loss_meter.update(loss.item(), batch_size)
        acc1_meter.update(acc1.item(), batch_size)
        acc5_meter.update(acc5.item(), batch_size)

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')

    # Sync every meter whose average is returned. This keeps the returned
    # loss aggregated the same way as the returned accuracies.
    # NOTE(review): AverageMeter.sync is assumed to be a cross-rank
    # collective, so every rank must reach these calls. Confirm.
    acc1_meter.sync()
    acc5_meter.sync()
    loss_meter.sync()
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
287
288
289if __name__ == '__main__':

Callers 1

main (Function) · 0.70

Calls 5

update (Method) · 0.95
sync (Method) · 0.95
AverageMeter (Class) · 0.90
accuracy (Function) · 0.90
size (Method) · 0.45

Tested by

no test coverage detected