(data_loader, model, device, use_amp=False)
| 136 | |
@torch.no_grad()
def evaluate(data_loader, model, device, use_amp=False):
    """Evaluate ``model`` on ``data_loader``; report loss and top-1/top-5 accuracy.

    Runs with gradients disabled and the model in ``eval()`` mode. Cross-entropy
    loss and top-1/top-5 accuracy are accumulated as batch-size-weighted
    averages. The meters are then synchronized across processes, a one-line
    summary is printed, and the global averages are returned.

    Args:
        data_loader: Iterable of batches. ``batch[0]`` is used as the model
            input and ``batch[-1]`` as the target; any elements in between are
            ignored.
        model: Network to evaluate. It is switched to ``eval()`` mode here and
            is not switched back before returning.
        device: Device that inputs and targets are moved to (``non_blocking``).
        use_amp: If True, run the forward pass and loss computation under
            ``torch.cuda.amp.autocast()``.

    Returns:
        dict mapping every meter name in ``metric_logger.meters`` to its
        ``global_avg``. This function populates ``'loss'``, ``'acc1'`` and
        ``'acc5'``.
    """
    import contextlib

    criterion = torch.nn.CrossEntropyLoss()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # Single forward path. nullcontext leaves any surrounding autocast state
    # untouched, matching the original un-wrapped non-AMP branch. By contrast,
    # autocast(enabled=False) would actively disable an outer autocast.
    amp_ctx = torch.cuda.amp.autocast if use_amp else contextlib.nullcontext

    # switch to evaluation mode
    model.eval()
    for batch in metric_logger.log_every(data_loader, 10, header):
        images = batch[0].to(device, non_blocking=True)
        target = batch[-1].to(device, non_blocking=True)

        # compute output
        with amp_ctx():
            output = model(images)
            loss = criterion(output, target)

        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        # CrossEntropyLoss() defaults to reduction='mean', so `loss` is a
        # per-batch mean. Weight it by batch size, exactly like acc1/acc5.
        # Otherwise a short final batch (or uneven shards) skews the global loss.
        batch_size = images.shape[0]
        metric_logger.meters['loss'].update(loss.item(), n=batch_size)
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
no test coverage detected