| 59 | return text_features |
| 60 | |
def get_text_features_list(self, texts, train=False):
    """Encode a batch of text prompts with the CLIP text encoder.

    Args:
        texts: Iterable of prompt strings. Each element is tokenized
            separately and the results are stacked into one batch.
        train: If True, run the encoder with autograd tracking so the
            returned features can be back-propagated through. If False,
            run under ``torch.no_grad()``.

    Returns:
        The output of ``self.model.encode_text`` for the tokenized batch.
        Presumably shape (len(texts), embed_dim); this is not
        established here, so confirm it against the model.
    """
    # Tokenize once for both branches. Token ids are integer tensors with
    # no autograd history, so this work need not sit inside no_grad().
    # list(...) keeps the original per-element semantics (including for
    # generators) while letting clip.tokenize batch the whole input.
    text_inputs = clip.tokenize(list(texts)).to(self.device)

    if train:
        # Do not force-enable grad here: an enclosing no_grad() context
        # set by the caller is still respected, as before.
        return self.model.encode_text(text_inputs)

    with torch.no_grad():
        return self.model.encode_text(text_inputs)
| 73 | |
| 74 | def get_similarity(self, image_features, text_features): |
| 75 | image_features /= image_features.norm(dim=-1, keepdim=True) |