MCPcopy
hub / github.com/facebookresearch/MetaCLIP / validate_zeroshot

Function validate_zeroshot

clipeval/eval_zeroshot.py:75–127  ·  view source on GitHub ↗
(val_loader, templates, labels, model, tokenizer, is_acc, classnorm=False)

Source from the content-addressed store, hash-verified

73
@torch.no_grad()
def validate_zeroshot(val_loader, templates, labels, model, tokenizer, is_acc, classnorm=False):
    """Run zero-shot classification of ``val_loader`` against class text prompts.

    The class text embeddings are computed once, on the first batch, via
    ``build_text_features(templates, labels, model, tokenizer, classnorm=classnorm)``.
    Each image batch is then:
    1. encoded with ``model.encode_image``;
    2. optionally standardized with the ``mean``/``std`` returned by
       ``build_text_features`` (when ``classnorm`` is true);
    3. L2-normalized;
    4. scored against the text embeddings with a matrix product.

    Args:
        val_loader: Iterable of batches. Each batch is either a tuple/list whose
            first two items are ``(images, target)``, or a dict with keys
            ``"pixel_values"`` and ``"targets"``.
        templates: Prompt templates, forwarded to ``build_text_features``.
        labels: Class labels, forwarded to ``build_text_features``.
        model: Model exposing ``encode_image``; it is moved to CUDA and put in
            eval mode by this function.
        tokenizer: Tokenizer, forwarded to ``build_text_features``.
        is_acc: If true, return top-1 accuracy; otherwise return raw logits.
        classnorm: Whether to standardize image features with the statistics
            returned by ``build_text_features``.

    Returns:
        If ``is_acc``: top-1 accuracy as a percentage, ``100 * correct / total``.
        Otherwise: a ``(logits, targets)`` pair holding every batch's logits and
        targets concatenated along dim 0, on CPU.

    Raises:
        ValueError: If a batch is neither a tuple/list nor a dict, or if
            ``val_loader`` produces nothing to evaluate.
    """
    # switch to evaluate mode
    model.cuda()
    model.eval()

    total_top1 = 0
    total_images = 0
    all_outputs = []
    all_targets = []

    text_features = None

    for samples in val_loader:
        if text_features is None:
            # Encoded lazily inside the loop, so an empty loader never
            # triggers text encoding.
            print('=> encoding captions')
            text_features, mean, std = build_text_features(templates, labels, model, tokenizer, classnorm=classnorm)

        if isinstance(samples, (tuple, list)):
            images, target = samples[0], samples[1]
        elif isinstance(samples, dict):
            images, target = samples["pixel_values"], samples["targets"]
        else:
            raise ValueError(f"unknown sample type: {type(samples)}")

        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # encode images
        image_features = model.encode_image(images)
        if classnorm:
            image_features = (image_features - mean) / std
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Cosine similarity as logits. This assumes build_text_features returns
        # unit-norm rows (TODO: confirm); only the image side is normalized here.
        logits_per_image = (image_features @ text_features.t()).cpu()
        target = target.cpu()

        if is_acc:
            pred = logits_per_image.argmax(dim=1)
            total_top1 += pred.eq(target).sum().item()
            total_images += images.size(0)
        else:
            all_outputs.append(logits_per_image)
            all_targets.append(target)

    if is_acc:
        # Guard the division: an empty loader would otherwise raise
        # ZeroDivisionError.
        if total_images == 0:
            raise ValueError("val_loader yielded no images; accuracy is undefined")
        return 100 * total_top1 / total_images

    # torch.cat on an empty list raises an opaque error; report the real cause.
    if not all_outputs:
        raise ValueError("val_loader yielded no batches; nothing to evaluate")
    return torch.cat(all_outputs), torch.cat(all_targets)
128
129
130def accuracy(output, target, topk=(1,)):

Callers 2

evaluate_logits (Function) · 0.90
evaluate (Function) · 0.85

Calls 2

build_text_features (Function) · 0.85
encode_image (Method) · 0.80

Tested by

no test coverage detected