(d, val_loader, templates, labels, model, tokenizer, classnorm=False)
| 27 | |
| 28 | |
# Datasets scored with mean per-class accuracy (via mean_per_class).
_MEAN_PER_CLASS_DATASETS = frozenset({'FGVCAircraft', 'OxfordPets', 'Caltech101', 'Flowers102'})
# Datasets whose metric is computed here from validate_zeroshot's raw return
# (called with is_acc=False); every other dataset is scored directly by
# validate_zeroshot (is_acc=True). Keeping this as the single source of truth
# means the is_acc flag and the metric dispatch below cannot drift apart.
_RAW_OUTPUT_DATASETS = _MEAN_PER_CLASS_DATASETS | {'Kinetics700', 'HatefulMemes'}


def evaluate(d, val_loader, templates, labels, model, tokenizer, classnorm=False):
    """Run zero-shot evaluation on dataset ``d`` and return its headline metric.

    Args:
        d: Dataset name; selects which metric is reported.
        val_loader, templates, labels, model, tokenizer, classnorm:
            Forwarded unchanged to ``validate_zeroshot``.

    Returns:
        - For datasets in ``_MEAN_PER_CLASS_DATASETS``: ``mean_per_class`` of
          the raw outputs.
        - For ``'Kinetics700'``: the mean of top-1 and top-5 accuracy,
          converted with ``.item()``.
        - For ``'HatefulMemes'``: ``roc_auc`` of the raw outputs.
        - For any other dataset: the value ``validate_zeroshot`` returns when
          called with ``is_acc=True``, unchanged.

    Raises:
        ValueError: if ``d`` is in ``_RAW_OUTPUT_DATASETS`` but has no metric
            branch below (guards against editing the set without the dispatch).
    """
    print(f'Evaluating {d}')

    is_acc = d not in _RAW_OUTPUT_DATASETS
    acc_or_outputs = validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc, classnorm)

    if is_acc:
        # validate_zeroshot already produced the metric for this dataset.
        return acc_or_outputs

    # Otherwise acc_or_outputs is unpacked as positional args to the metric
    # function — presumably (outputs, targets); confirm against validate_zeroshot.
    if d in _MEAN_PER_CLASS_DATASETS:
        return mean_per_class(*acc_or_outputs)
    if d == 'Kinetics700':
        top1, top5 = accuracy(*acc_or_outputs, topk=(1, 5))
        return ((top1 + top5) / 2).item()
    if d == 'HatefulMemes':
        return roc_auc(*acc_or_outputs)

    raise ValueError(f'No metric defined for dataset {d!r}')
| 48 | |
| 49 | |
| 50 | @torch.no_grad() |
no test coverage detected