Test generating an answer using the RAG retriever.
(rag_retriever)
# End-to-end functionality test
def test_generate_answer(rag_retriever):
    """Test generating an answer using the RAG retriever.

    Both ``rag_retriever.retrieve`` and the generator model's ``generate``
    method are patched, so this exercises only the glue logic inside
    ``generate_answer``. It checks three things:

    - the query and ``top_k`` (as ``n_docs``) are forwarded to ``retrieve``;
    - ``max_new_tokens`` is forwarded to ``generate``;
    - the result is a non-empty string.
    """
    # Canned retrieval result: (embeddings, doc ids, doc dicts).
    # NOTE(review): this layout presumably mirrors what ``retrieve`` returns;
    # confirm against the retriever implementation if that signature changes.
    retrieved = (
        np.array([[[0.1] * 8, [0.2] * 8]]),  # shape (1, 2, 8): 1 query, 2 docs, 8-dim
        np.array([[1, 2]]),
        [
            {
                "text": ["context1", "context2"],
                "id": ["doc1", "doc2"],
                "title": ["Doc 1", "Doc 2"],
            }
        ],
    )

    # Patch both collaborators with patch.object so each is restored when the
    # ``with`` block exits. Assigning a Mock directly onto generator_model is
    # never undone, and would leak into other tests if the fixture (or its
    # model) is shared beyond this test.
    with patch.object(
        rag_retriever, "retrieve", return_value=retrieved
    ) as mock_retrieve, patch.object(
        rag_retriever.generator_model,
        "generate",
        return_value=torch.tensor([[1, 2, 3]]),
    ) as mock_generate:
        answer = rag_retriever.generate_answer(
            "test query", top_k=2, max_new_tokens=100
        )

        # The answer should be a non-empty string.
        assert isinstance(answer, str)
        assert len(answer) > 0

        # retrieve must receive the query and top_k (as n_docs) by keyword.
        mock_retrieve.assert_called_once()
        retrieve_kwargs = mock_retrieve.call_args.kwargs
        assert retrieve_kwargs["n_docs"] == 2
        assert retrieve_kwargs["query"] == "test query"

        # generate must receive max_new_tokens by keyword.
        mock_generate.assert_called_once()
        assert mock_generate.call_args.kwargs["max_new_tokens"] == 100
nothing calls this directly
no test coverage detected