(d, val_loader, templates, labels, model, tokenizer, classnorm=False)
| 18 | |
| 19 | |
def evaluate_logits(d, val_loader, templates, labels, model, tokenizer, classnorm=False):
    """Run zero-shot validation for dataset ``d`` and score it with that dataset's metric.

    The metric is selected by dataset name:

    * 'FGVCAircraft', 'OxfordPets', 'Caltech101', 'Flowers102' -> ``mean_per_class``
    * 'Kinetics700' -> average of the top-1 and top-5 values from ``accuracy``
    * 'HatefulMemes' -> ``roc_auc``
    * anything else -> percentage of rows whose argmax over dim 1 of
      ``outputs[0]`` equals the matching entry of ``outputs[1]``

    Args:
        d: Dataset name; used for the progress message and metric selection.
        val_loader, templates, labels, model, tokenizer: Forwarded unchanged
            to ``validate_zeroshot``.
        classnorm: Forwarded as the last positional argument of
            ``validate_zeroshot``.

    Returns:
        A ``(metric, outputs)`` tuple, where ``outputs`` is exactly what
        ``validate_zeroshot`` returned.
    """
    print('Evaluating {}'.format(d))

    # NOTE(review): the literal False is the sixth positional argument of
    # validate_zeroshot; its meaning is not visible here -- confirm against
    # that function's signature.
    outputs = validate_zeroshot(val_loader, templates, labels, model, tokenizer, False, classnorm)

    mean_per_class_datasets = ('FGVCAircraft', 'OxfordPets', 'Caltech101', 'Flowers102')
    if d in mean_per_class_datasets:
        return mean_per_class(*outputs), outputs

    if d == 'Kinetics700':
        top1, top5 = accuracy(*outputs, topk=(1, 5))
        # Average the two values first, then convert to a plain Python number.
        return ((top1 + top5) / 2).item(), outputs

    if d == 'HatefulMemes':
        return roc_auc(*outputs), outputs

    # Default metric: plain accuracy expressed as a percentage.
    # Presumably outputs[0] holds per-class scores and outputs[1] the targets
    # -- verify against validate_zeroshot.
    predictions = outputs[0].argmax(dim=1)
    num_correct = predictions.eq(outputs[1]).sum()
    return num_correct.item() / float(predictions.size(0)) * 100.0, outputs
| 39 | |
| 40 | |
| 41 | @torch.no_grad() |
no test coverage detected