Get the top num_examples most similar examples to the images.
(self, batch, num_examples)
| 64 | return features |
| 65 | |
def find(self, batch, num_examples):
    """
    Get the top num_examples most similar examples to the images.

    Each image in ``batch`` is preprocessed with ``self.image_processor``,
    stacked, embedded with ``self.model.encode_image`` and L2-normalised.
    It is then scored against every row of ``self.features`` by dot product.

    Args:
        batch: Iterable of images accepted by ``self.image_processor``.
        num_examples: Maximum number of examples to return per image. Values
            larger than the number of precomputed features are clamped to
            that number; values below 1 yield empty lists.

    Returns:
        One list per input image. Each list holds up to ``num_examples``
        items of ``self.dataset``, ordered with the most similar item last.
    """
    # Switch to evaluation mode (the model is left in this mode afterwards).
    self.model.eval()

    with torch.no_grad():
        inputs = torch.stack([self.image_processor(image) for image in batch]).to(
            self.device
        )

        # Embed the query images and L2-normalise each embedding.
        query_feature = self.model.encode_image(inputs)
        query_feature = query_feature / query_feature.norm(dim=-1, keepdim=True)

        features = self.features
        # Match the precomputed features' dtype/device so the matmul below
        # cannot fail on e.g. a half-precision or GPU model output.
        query_feature = query_feature.to(device=features.device, dtype=features.dtype)

        if query_feature.ndim == 1:
            query_feature = query_feature.unsqueeze(0)

        # (num_queries, num_candidates) similarity matrix. Deliberately NOT
        # squeezed: with a single precomputed feature, squeezing removed the
        # candidate axis and made queries look like candidates.
        similarity = query_feature @ features.T

        # topk returns indices ordered from most to least similar. k is
        # clamped because torch.topk raises when k exceeds the dimension size.
        k = max(0, min(num_examples, similarity.shape[-1]))
        indices = similarity.topk(k, dim=-1).indices

    # Return with the most similar images last. tolist() yields plain Python
    # ints, so any sequence-like dataset can be indexed.
    return [[self.dataset[i] for i in reversed(row)] for row in indices.tolist()]
no outgoing calls
no test coverage detected