Query the top_k related chunks from the specified collections. :param collection_ids: the collection ids. :param top_k: the number of chunks to query. :param max_tokens: the maximum number of tokens in the chunks. :param score_threshold: the minimum score threshold to return the chunks. :param query_text: the query text. :return: the related chunks.
(
collection_ids: List[str],
top_k: int,
max_tokens: Optional[int],
score_threshold: Optional[float],
query_text: str,
)
| 14 | |
| 15 | |
async def query_chunks(
    collection_ids: List[str],
    top_k: int,
    max_tokens: Optional[int],
    score_threshold: Optional[float],
    query_text: str,
) -> List[Chunk]:
    """
    Query the top_k related chunks from the specified collections.

    A single query vector is computed from ``query_text`` and used for every
    collection, so all collections must share the same embedding model.

    :param collection_ids: the collection ids; at least one id is required.
    :param top_k: the number of chunks to query.
    :param max_tokens: the maximum number of tokens in the chunks.
    :param score_threshold: the minimum score threshold to return the chunks.
    :param query_text: the query text.
    :return: the chunks returned by the chunk query for the computed query vector.
    """
    # NOTE(review): function-scope import kept as in the original; presumably it
    # avoids a circular import with app.operators -- confirm before hoisting.
    from app.operators import collection_ops

    # Without this guard an empty id list fails further down with an IndexError
    # on collections[0]; report it as a request validation error instead.
    if not collection_ids:
        raise_http_error(
            ErrorCode.REQUEST_VALIDATION_ERROR, message="At least one collection id must be specified."
        )

    # fetch all collections
    collections = []
    for collection_id in collection_ids:
        # currently, raise error when collection is not found
        collection: Collection = await collection_ops.get(collection_id=collection_id)
        collections.append(collection)

    # check all collections have the same embedding model
    embedding_model_ids = {collection.embedding_model_id for collection in collections}
    if len(embedding_model_ids) > 1:
        raise_http_error(
            ErrorCode.REQUEST_VALIDATION_ERROR, message="The specified collections use different embedding models."
        )

    # validate model (all collections share it, so the first one is representative)
    embedding_model: Model = await get_model(collections[0].embedding_model_id)

    # compute query vector
    query_vector = await embed_query(
        query=query_text,
        embedding_model=embedding_model,
        embedding_size=collections[0].embedding_size,
    )

    # query related chunks
    chunks = await db_chunk.query_chunks(
        collections=collections,
        top_k=top_k,
        max_tokens=max_tokens,
        score_threshold=score_threshold,
        query_vector=query_vector,
    )
    return chunks
no test coverage detected