Test retrieving documents using the RAG retriever.
(rag_retriever)
| 392 | |
| 393 | |
def test_retrieve_documents(rag_retriever):
    """Test retrieving documents using the RAG retriever.

    Calls ``rag_retriever.retrieve`` with random question hidden states and
    checks the batch / document dimensions of the three returned values
    (document embeddings, document ids, document dicts).
    """
    # Seeded generator so the query vectors -- and therefore which documents
    # get retrieved -- are reproducible from run to run.
    rng = np.random.default_rng(0)
    # Hidden dim of 8 is intended to match the embedding size of the
    # fixture's test data. Assumed layout: (batch_size, seq_len, hidden_dim)
    # -- TODO confirm against the retriever's expected input shape.
    question_hidden_states = rng.random((1, 8, 8))

    # Exercise the retriever's ``retrieve`` directly; nothing is patched here.
    doc_embeds, doc_ids, doc_dicts = rag_retriever.retrieve(
        question_hidden_states=question_hidden_states, n_docs=2, query="test query"
    )

    # Verify the results
    assert doc_embeds.shape == (1, 2, 8)  # (batch_size, n_docs, embedding_dim)
    assert len(doc_ids) == 1  # One batch
    assert len(doc_dicts) == 1  # One batch
    assert len(doc_dicts[0]["text"]) == 2  # Two documents
    assert len(doc_dicts[0]["id"]) == 2  # Two document IDs
| 412 | |
| 413 | |
| 414 | # End-to-end functionality test |