A helper method to generate an answer for a single query string. Args: `query`: the query to answer; `top_k`: number of documents to retrieve; `max_new_tokens`: maximum number of tokens to generate. Returns: the generated answer string.
(
self, query: str, top_k: int = 5, max_new_tokens: int = 100
)
| 267 | ) |
| 268 | |
def generate_answer(
    self, query: str, top_k: int = 5, max_new_tokens: int = 100
) -> str:
    """Generate an answer for a single query string.

    The query is embedded with the question encoder, ``top_k`` documents
    are fetched through ``self.retrieve``, their texts are joined into a
    ``Context: ... Question: ... Answer:`` prompt, and the generator's
    first output sequence is decoded and returned.

    Args:
        query: The query to answer.
        top_k: Number of documents to retrieve.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded generator output with special tokens skipped.

    Raises:
        ValueError: If ``question_encoder`` or ``generator_model`` is not
            set on this instance.
    """
    if not self.question_encoder or not self.generator_model:
        raise ValueError(
            "`question_encoder` and `generator_model` must be provided to use `generate_answer`."
        )
    torch = get_torch()

    # Inference only: disable autograd so the encoder forward pass does not
    # build (and hold on to) a gradient graph that is discarded right after.
    with torch.no_grad():
        inputs = self.question_encoder_tokenizer(query, return_tensors="pt").to(
            self.question_encoder.device
        )
        question_embeddings = self.question_encoder(**inputs).pooler_output

    # The retriever is handed a float32 NumPy array living on the CPU.
    question_embeddings = (
        question_embeddings.detach().cpu().to(torch.float32).numpy()
    )
    _, _, doc_batch = self.retrieve(question_embeddings, n_docs=top_k, query=query)

    # Only one query was encoded, so only the first batch entry is read.
    # NOTE(review): assumes each entry exposes a "text" sequence of strings
    # — confirm against `retrieve`.
    contexts = doc_batch[0]["text"] if doc_batch else []
    # filter(None, ...) drops empty / None texts before joining.
    context_str = "\n\n".join(filter(None, contexts))

    prompt = f"Context: {context_str}\n\nQuestion: {query}\n\nAnswer:"

    generator_inputs = self.generator_tokenizer(prompt, return_tensors="pt").to(
        self.generator_model.device
    )
    with torch.no_grad():
        output_ids = self.generator_model.generate(
            **generator_inputs, max_new_tokens=max_new_tokens
        )

    # NOTE(review): with a decoder-only generator the decoded sequence
    # presumably still contains the prompt — confirm whether callers expect that.
    return self.generator_tokenizer.decode(output_ids[0], skip_special_tokens=True)
| 309 | |
| 310 | def _default_format_document(self, doc: dict[str, Any]) -> str: |
| 311 | """Default document formatting function. |