Overrides the base `retrieve` method to fetch documents from Feast. This method processes a batch of questions, queries the vector store, and formats the results into the 3-part tuple expected by the RAG model: (document_embeddings, document_ids, document_dictionaries).
( # type: ignore [override]
self,
question_hidden_states: np.ndarray,
n_docs: int,
query: Optional[str] = None,
)
| 128 | return self._vector_store |
| 129 | |
| 130 | def retrieve( # type: ignore [override] |
| 131 | self, |
| 132 | question_hidden_states: np.ndarray, |
| 133 | n_docs: int, |
| 134 | query: Optional[str] = None, |
| 135 | ) -> tuple[np.ndarray, np.ndarray, list[dict]]: |
| 136 | """ |
| 137 | Overrides the base `retrieve` method to fetch documents from Feast. |
| 138 | |
| 139 | This method processes a batch of questions, queries the vector store, |
| 140 | and formats the results into the 3-part tuple expected by the RAG model: |
| 141 | (document_embeddings, document_ids, document_dictionaries). |
| 142 | |
| 143 | Args: |
| 144 | question_hidden_states (np.ndarray): |
| 145 | Hidden state representation of the question from the encoder. |
| 146 | Expected shape is (batch_size, seq_len, hidden_dim). |
| 147 | n_docs (int): |
| 148 | Number of top documents to retrieve. |
| 149 | query (Optional[str]): |
| 150 | Optional raw query string. If not provided and search_type is "text" or "hybrid", |
| 151 | it will be decoded from question_hidden_states. |
| 152 | |
| 153 | Returns: |
| 154 | Tuple containing: |
| 155 | - retrieved_doc_embeds (np.ndarray): |
| 156 | Embeddings of the retrieved documents with shape (batch_size, n_docs, embed_dim). |
| 157 | - doc_ids (np.ndarray): |
| 158 | Array of document IDs or passage identifiers with shape (batch_size, n_docs). |
| 159 | - doc_dicts (list[dict]): |
| 160 | List of dictionaries containing document metadata and text. |
| 161 | Each dictionary has keys "text", "id", and "title". |
| 162 | """ |
| 163 | batch_size = question_hidden_states.shape[0] |
| 164 | |
| 165 | # Convert the question hidden states into a list of 1D query vectors. |
| 166 | pooled_query_vectors = [] |
| 167 | for i in range(batch_size): |
| 168 | pooled = question_hidden_states[i] |
| 169 | # Perform normalization to create a unit vector. |
| 170 | norm = np.linalg.norm(pooled) |
| 171 | if norm > 0: |
| 172 | pooled = pooled / norm |
| 173 | pooled_query_vectors.append(pooled) |
| 174 | |
| 175 | # Determine embedding dimension for padding |
| 176 | emb_dim = ( |
| 177 | pooled_query_vectors[0].shape[-1] |
| 178 | if pooled_query_vectors and pooled_query_vectors[0] is not None |
| 179 | else self.config.retrieval_vector_size |
| 180 | ) |
| 181 | |
| 182 | # Retrieve documents for each query in batch |
| 183 | batch_embeddings, batch_doc_ids, batch_metadata = [], [], [] |
| 184 | |
| 185 | for i in range(batch_size): |
| 186 | query_vector = pooled_query_vectors[i] |
| 187 | if isinstance(query, list): |