(templates, labels, model, tokenizer, skip_text_projection=False, classnorm=False)
| 49 | |
@torch.no_grad()
def build_text_features(templates, labels, model, tokenizer, skip_text_projection=False, classnorm=False):
    """Build one L2-normalised text embedding per class from prompt templates.

    For every entry in ``labels`` each template is formatted with the class
    name (or with every synonym, when the entry is a list/tuple of names),
    the prompts are tokenized and encoded with ``model.encode_text``, each
    prompt embedding is L2-normalised, and the prompt embeddings are averaged
    into a single class embedding.

    Args:
        templates: Iterable of format strings with one ``{}`` placeholder.
        labels: Iterable of class names; an entry may also be a list or
            tuple of synonymous names that are all expanded into prompts.
        model: Object exposing ``parameters()`` and ``encode_text(tokens)``.
        tokenizer: Callable mapping a list of strings to a tensor.
        skip_text_projection: Accepted for interface compatibility; it is
            not read anywhere in this function.
        classnorm: When true, standardise the class embeddings across the
            class dimension before the final L2 normalisation.

    Returns:
        A tuple ``(text_features, mean, std)``. ``text_features`` stacks the
        class embeddings along dim 0 and is L2-normalised along the last
        dim. ``mean`` and ``std`` are the statistics used for class
        normalisation (each with a leading singleton dim), or ``None`` when
        ``classnorm`` is false.
    """
    # Loop-invariant: resolve the target device once instead of per label.
    device = next(model.parameters()).device

    per_class = []
    for label in labels:
        # Treat tuples like lists so a tuple of synonyms is expanded instead
        # of being formatted into the prompt as the tuple's repr.
        names = label if isinstance(label, (list, tuple)) else [label]
        prompts = [template.format(name) for template in templates for name in names]

        tokens = tokenizer(prompts).to(device, non_blocking=True)
        embeddings = model.encode_text(tokens)
        embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)
        # Average over all prompts of this class to get one vector per class.
        per_class.append(embeddings.mean(dim=0))

    text_features = torch.stack(per_class, dim=0)

    mean, std = None, None
    if classnorm:
        # Statistics are taken across classes (dim 0), per feature dimension.
        # NOTE(review): a feature dimension with zero spread across classes
        # (or a single class) gives a zero/undefined std and therefore
        # non-finite features - confirm callers never hit that case.
        mean = text_features.mean(dim=0)[None, :]
        std = text_features.std(dim=0)[None, :]
        text_features = (text_features - mean) / std

    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    return text_features, mean, std
| 72 | |
| 73 | |
| 74 | @torch.no_grad() |
no test coverage detected