Creates embeddings for the request's input text
(request: EmbeddingsRequest, model_name: str = None)
@app.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
@app.post("/v1/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(request: EmbeddingsRequest, model_name: str = None):
    """Creates embeddings for the text.

    The (processed) request input is sent to ``get_embedding`` in consecutive
    slices of at most ``WORKER_API_EMBEDDING_BATCH_SIZE`` items. The returned
    vectors are concatenated in order. Each vector is tagged with its position
    in the full input, not its position within a slice.

    Args:
        request: The embeddings request body. ``request.model`` is filled in
            from ``model_name`` when the body leaves it unset, and
            ``request.input`` is replaced by the result of ``process_input``.
        model_name: Model name from the ``/v1/engines/{model_name}/embeddings``
            route. It is only used when ``request.model`` is None.

    Returns:
        On success, an ``EmbeddingsResponse`` converted with
        ``.dict(exclude_none=True)``. Otherwise, whatever ``check_model``
        returned when it reported a problem, or a ``create_error_response``
        result for the first slice whose reply carries a non-zero
        ``error_code``.
    """
    # The body's "model" field takes precedence over the route parameter.
    if request.model is None:
        request.model = model_name

    model_error = await check_model(request)
    if model_error is not None:
        return model_error

    request.input = process_input(request.model, request.input)

    inputs = request.input
    step = WORKER_API_EMBEDDING_BATCH_SIZE
    results = []
    total_tokens = 0

    # Slicing clamps at the end of the sequence, so the last slice may be short.
    for start in range(0, len(inputs), step):
        chunk = inputs[start : start + step]
        worker_reply = await get_embedding(
            {
                "model": request.model,
                "input": chunk,
                "encoding_format": request.encoding_format,
            }
        )

        # Abort the whole request on the first failing slice.
        if "error_code" in worker_reply and worker_reply["error_code"] != 0:
            return create_error_response(
                worker_reply["error_code"], worker_reply["text"]
            )

        # `start` is the offset of this slice in the full input, so
        # `start + offset` is the vector's global index.
        results.extend(
            {
                "object": "embedding",
                "embedding": vector,
                "index": start + offset,
            }
            for offset, vector in enumerate(worker_reply["embedding"])
        )
        total_tokens += worker_reply["token_num"]

    # Embeddings consume only prompt tokens; completion_tokens is None so it
    # is dropped by exclude_none below.
    usage = UsageInfo(
        prompt_tokens=total_tokens,
        total_tokens=total_tokens,
        completion_tokens=None,
    )
    response = EmbeddingsResponse(data=results, model=request.model, usage=usage)
    return response.dict(exclude_none=True)
| 753 | |
| 754 | |
| 755 | async def get_embedding(payload: Dict[str, Any]): |
nothing calls this directly
no test coverage detected
searching dependent graphs…