(val_loader, templates, labels, model, tokenizer, is_acc, classnorm=False)
| 73 | |
@torch.no_grad()
def validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc, classnorm=False):
    """Evaluate ``model`` as a zero-shot classifier over ``val_loader``.

    Class text embeddings are built once, lazily on the first batch, by
    ``build_text_features(templates, labels, model, tokenizer, classnorm=...)``.
    Each batch of images is encoded with ``model.encode_image``, optionally
    standardised, L2-normalised along the last dim, and scored against the text
    embeddings with a matrix product. The model and every image batch are
    moved to CUDA, so a GPU is required.

    Args:
        val_loader: Iterable of batches. Each batch is either a tuple/list whose
            first two items are ``(images, target)``, or a dict with
            ``"pixel_values"`` and ``"targets"`` keys (presumably torch tensors).
        templates: Forwarded unchanged to ``build_text_features``.
        labels: Forwarded unchanged to ``build_text_features``.
        model: Object exposing ``cuda()``, ``eval()`` and ``encode_image()``.
        tokenizer: Forwarded unchanged to ``build_text_features``.
        is_acc: If True, return top-1 accuracy; otherwise return the raw
            logits and targets.
        classnorm: If True, standardise image features with the ``mean`` and
            ``std`` returned by ``build_text_features`` before L2 normalisation.

    Returns:
        If ``is_acc``: top-1 accuracy as a percentage (0-100).
        Otherwise: ``(logits, targets)``, each batch's logits / targets
        concatenated along dim 0, on the CPU.

    Raises:
        ValueError: if a batch has an unsupported type, or if the loader
            yields nothing to evaluate.
    """
    # switch to evaluate mode
    model.cuda()
    model.eval()

    total_top1 = 0
    total_images = 0

    all_outputs = []
    all_targets = []

    text_features = mean = std = None

    for samples in val_loader:
        # Built lazily so an empty loader never pays for text encoding.
        if text_features is None:
            print('=> encoding captions')
            text_features, mean, std = build_text_features(templates, labels, model, tokenizer, classnorm=classnorm)

        if isinstance(samples, (tuple, list)):
            images, target = samples[0], samples[1]
        elif isinstance(samples, dict):
            images, target = samples["pixel_values"], samples["targets"]
        else:
            raise ValueError("unknown sample type", type(samples))

        images = images.cuda(non_blocking=True)
        # Targets are only consumed on the CPU (argmax comparison / collection
        # below), so skip the pointless host->device->host round trip.
        target = target.cpu()

        # encode images
        image_features = model.encode_image(images)

        if classnorm:
            image_features = (image_features - mean) / std

        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        # Dot product against the text embeddings. This is cosine similarity only
        # if build_text_features returns L2-normalised rows -- TODO confirm.
        logits_per_image = image_features @ text_features.t()
        logits_per_image = logits_per_image.cpu()
        if is_acc:
            # measure top-1 accuracy
            pred = logits_per_image.argmax(dim=1)
            correct = pred.eq(target).sum()
            total_top1 += correct.item()
            total_images += images.size(0)
        else:
            all_outputs.append(logits_per_image)
            all_targets.append(target)

    if is_acc:
        # Guard against an empty loader instead of an opaque ZeroDivisionError.
        if total_images == 0:
            raise ValueError("val_loader yielded no images; cannot compute accuracy")
        return 100 * total_top1 / total_images
    # torch.cat([]) would raise an opaque RuntimeError on an empty loader.
    if not all_outputs:
        raise ValueError("val_loader yielded no batches; nothing to concatenate")
    return torch.cat(all_outputs), torch.cat(all_targets)
| 128 | |
| 129 | |
| 130 | def accuracy(output, target, topk=(1,)): |
no test coverage detected