MCPcopy
hub / github.com/facebookresearch/MetaCLIP / build_text_features

Function build_text_features

clipeval/eval_zeroshot.py:51–71  ·  view source on GitHub ↗
(templates, labels, model, tokenizer, skip_text_projection=False, classnorm=False)

Source from the content-addressed store, hash-verified

49
@torch.no_grad()
def build_text_features(templates, labels, model, tokenizer, skip_text_projection=False, classnorm=False):
    """Build one L2-normalized text embedding per label from prompt templates.

    For every entry in ``labels`` each template is formatted with the label
    (or with every alias, when the entry is a list), the prompts are
    tokenized and encoded with ``model.encode_text``, each prompt embedding is
    L2-normalized, and the embeddings are averaged into one vector per label.

    Args:
        templates: Iterable of strings formatted via ``template.format(name)``.
        labels: Iterable of labels. An entry that is a ``list`` is treated as
            a group of aliases for one class; every alias is combined with
            every template and all resulting prompts are averaged together.
        model: Object exposing ``parameters()`` (used to find the target
            device) and ``encode_text(tokens)``.
            Assumes ``encode_text`` returns one embedding row per prompt,
            i.e. shape ``(num_prompts, dim)`` -- TODO confirm against model.
        tokenizer: Callable taking a list of strings and returning an object
            with a ``.to(device, non_blocking=...)`` method.
        skip_text_projection: Accepted for interface compatibility; it is
            not read anywhere in this function.
        classnorm: When true, standardize every feature dimension across
            labels (subtract the mean, divide by the std) before the final
            L2 normalization.

    Returns:
        Tuple ``(text_features, mean, std)``. ``text_features`` has one
        L2-normalized row per label. ``mean`` and ``std`` are the per-dimension
        statistics with a leading singleton axis when ``classnorm`` is true,
        otherwise both are ``None``.

    Note:
        With ``classnorm=True`` the division by ``std`` is not guarded: a
        dimension that is constant across labels (or fewer than two labels,
        given ``torch.std``'s default estimator) produces NaN/inf features.
    """
    # The device is the same for every label, so look it up once instead of
    # re-walking the model parameters on each loop iteration.
    device = next(model.parameters()).device

    text_features = []
    for label in labels:
        # A plain label is treated as a single-alias group so both cases share
        # the same template x alias expansion (templates vary slowest).
        names = label if isinstance(label, list) else [label]
        prompts = [template.format(name) for template in templates for name in names]

        tokens = tokenizer(prompts).to(device, non_blocking=True)
        class_embeddings = model.encode_text(tokens)
        # Normalize each prompt embedding before averaging so every prompt
        # contributes a unit-length direction to the class embedding.
        class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True)
        text_features.append(class_embeddings.mean(dim=0))

    text_features = torch.stack(text_features, dim=0)

    mean, std = None, None
    if classnorm:
        # Statistics are taken across labels; the leading axis is kept so the
        # returned tensors broadcast against a (num_items, dim) matrix.
        mean = text_features.mean(dim=0)[None, :]
        std = text_features.std(dim=0)[None, :]
        text_features = (text_features - mean) / std

    # The mean of unit vectors (and the standardized features) is generally
    # not unit length, so re-normalize the final per-label embeddings.
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features, mean, std
72
73
74@torch.no_grad()

Callers 1

validate_zeroshotFunction · 0.85

Calls 1

encode_textMethod · 0.45

Tested by

no test coverage detected