Repository: github.com/mlfoundations/open_flamingo

Method find

open_flamingo/eval/rices.py:66–96

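For each image in batch, return the num_examples dataset examples whose precomputed image features are most similar to that image's normalized embedding, ordered from least to most similar (most similar last).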

(self, batch, num_examples)
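The retrieval that find performs can be sketched without torch or an image encoder. This is a minimal, dependency-free sketch, not the open_flamingo implementation: find_sketch and l2_normalize are hypothetical names, plain lists stand in for tensors, and the dataset features are assumed to be L2-normalized ahead of time (which is what a precomputed self.features needs for the dot product to equal cosine similarity).

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length (the role of query_feature.norm(...) in find())."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]


def find_sketch(query_features, dataset_features, dataset, num_examples):
    """Hypothetical list-based stand-in for find().

    query_features:   one raw embedding per query image.
    dataset_features: one embedding per dataset item, assumed already unit length.
    Returns, per query, num_examples dataset items ordered with the most similar last.
    """
    results = []
    for query in query_features:
        q = l2_normalize(query)
        # Dot product of two unit vectors is their cosine similarity.
        sims = [sum(a * b for a, b in zip(q, f)) for f in dataset_features]
        # Indices by descending similarity, keeping the top num_examples.
        top = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)[:num_examples]
        # Reverse so the best match comes last, like reversed(row) in find().
        results.append([dataset[i] for i in reversed(top)])
    return results


features = [l2_normalize(v) for v in ([1, 0], [0, 1], [1, 1])]
dataset = ["horizontal", "vertical", "diagonal"]
print(find_sketch([[2, 0.1]], features, dataset, num_examples=2))
# [['diagonal', 'horizontal']]
```

Sorting indices in descending order, truncating, and then reversing the row mirrors argsort(dim=-1, descending=True)[:, :num_examples] followed by reversed(row) in the source.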

Source

```python
def find(self, batch, num_examples):
    """
    Get the top num_examples most similar examples to the images.
    """
    # Switch to evaluation mode
    self.model.eval()

    with torch.no_grad():
        inputs = torch.stack([self.image_processor(image) for image in batch]).to(
            self.device
        )

        # Get the feature of the input image
        query_feature = self.model.encode_image(inputs)
        query_feature /= query_feature.norm(dim=-1, keepdim=True)
        query_feature = query_feature.detach().cpu()

        if query_feature.ndim == 1:
            query_feature = query_feature.unsqueeze(0)

        # Compute the similarity of the input image to the precomputed features
        similarity = (query_feature @ self.features.T).squeeze()

        if similarity.ndim == 1:
            similarity = similarity.unsqueeze(0)

        # Get the indices of the 'num_examples' most similar images
        indices = similarity.argsort(dim=-1, descending=True)[:, :num_examples]

    # Return with the most similar images last
    return [[self.dataset[i] for i in reversed(row)] for row in indices]
```
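The ndim checks and the squeeze()/unsqueeze(0) pair exist so that a single query image still yields a 2-D (B, N) similarity matrix, with B query images and N dataset items. The shape-only trace below uses a hypothetical helper that models torch.Tensor.squeeze() on plain tuples rather than calling torch. It suggests the guard handles B == 1 but not a dataset with a single item: for N == 1, squeeze() drops the N axis, so (4, 1) comes back as (1, 4), and (1, 1) collapses to a 0-d tensor that the ndim == 1 check does not catch.

```python
def squeeze_shape(shape):
    """Model torch.Tensor.squeeze() with no arguments: drop every size-1 dimension."""
    return tuple(d for d in shape if d != 1)


def find_similarity_shape(batch_size, num_dataset_items):
    """Trace the shape of `similarity` through the squeeze/unsqueeze steps in find().

    (B, D) @ (D, N) yields (B, N); find() then calls .squeeze() and,
    if the result is 1-D, .unsqueeze(0).
    """
    shape = (batch_size, num_dataset_items)
    shape = squeeze_shape(shape)
    if len(shape) == 1:
        shape = (1,) + shape  # unsqueeze(0) prepends a size-1 dimension
    return shape


for b, n in [(4, 1000), (1, 1000), (4, 1), (1, 1)]:
    print((b, n), "->", find_similarity_shape(b, n))
# (4, 1000) -> (4, 1000)
# (1, 1000) -> (1, 1000)
# (4, 1) -> (1, 4)
# (1, 1) -> ()
```

Because query_feature is already 2-D after the ndim guard, omitting .squeeze() would keep the (B, N) shape in every case. That is a suggested simplification, not the upstream code, and the N == 1 case is unlikely with a realistic retrieval set.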

Callers (4)

evaluate_captioning (Function) · 0.95
evaluate_vqa (Function) · 0.95
evaluate_classification (Function) · 0.95
getattr_recursive (Function) · 0.80
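The captioning, VQA, and classification evaluation routines above presumably use each returned row as the in-context demonstrations for one query image. Because find() returns the most similar example last, concatenating a row in order places the best match directly before the query. The sketch below shows only that ordering effect; build_prompt, the "caption" key, and the "<image>Output:...<|endofchunk|>" format are illustrative assumptions, not the evaluation scripts' actual prompt code.

```python
def build_prompt(demos, prefix="<image>Output:", end="<|endofchunk|>"):
    """Hypothetical prompt assembly from one row returned by find().

    The row is ordered least -> most similar, so the most similar demonstration
    ends up immediately before the open-ended query prefix appended last.
    """
    context = "".join(f"{prefix}{demo['caption']}{end}" for demo in demos)
    return context + prefix


row = [{"caption": "a red bus"}, {"caption": "a red double-decker bus"}]
print(build_prompt(row))
# <image>Output:a red bus<|endofchunk|><image>Output:a red double-decker bus<|endofchunk|><image>Output:
```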

Calls

no outgoing calls to other functions in this repository (calls into torch and the vision encoder are not listed)

Tested by

no test coverage detected