Creates a completion for the chat message
(request: ChatCompletionRequest)
| 411 | |
| 412 | @app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)]) |
| 413 | async def create_chat_completion(request: ChatCompletionRequest): |
| 414 | """Creates a completion for the chat message""" |
| 415 | error_check_ret = await check_model(request) |
| 416 | if error_check_ret is not None: |
| 417 | return error_check_ret |
| 418 | error_check_ret = check_requests(request) |
| 419 | if error_check_ret is not None: |
| 420 | return error_check_ret |
| 421 | |
| 422 | worker_addr = await get_worker_address(request.model) |
| 423 | |
| 424 | gen_params = await get_gen_params( |
| 425 | request.model, |
| 426 | worker_addr, |
| 427 | request.messages, |
| 428 | temperature=request.temperature, |
| 429 | top_p=request.top_p, |
| 430 | top_k=request.top_k, |
| 431 | presence_penalty=request.presence_penalty, |
| 432 | frequency_penalty=request.frequency_penalty, |
| 433 | max_tokens=request.max_tokens, |
| 434 | echo=False, |
| 435 | stop=request.stop, |
| 436 | ) |
| 437 | |
| 438 | max_new_tokens, error_check_ret = await check_length( |
| 439 | request, |
| 440 | gen_params["prompt"], |
| 441 | gen_params["max_new_tokens"], |
| 442 | worker_addr, |
| 443 | ) |
| 444 | |
| 445 | if error_check_ret is not None: |
| 446 | return error_check_ret |
| 447 | |
| 448 | gen_params["max_new_tokens"] = max_new_tokens |
| 449 | |
| 450 | if request.stream: |
| 451 | generator = chat_completion_stream_generator( |
| 452 | request.model, gen_params, request.n, worker_addr |
| 453 | ) |
| 454 | return StreamingResponse(generator, media_type="text/event-stream") |
| 455 | |
| 456 | choices = [] |
| 457 | chat_completions = [] |
| 458 | for i in range(request.n): |
| 459 | content = asyncio.create_task(generate_completion(gen_params, worker_addr)) |
| 460 | chat_completions.append(content) |
| 461 | try: |
| 462 | all_tasks = await asyncio.gather(*chat_completions) |
| 463 | except Exception as e: |
| 464 | return create_error_response(ErrorCode.INTERNAL_ERROR, str(e)) |
| 465 | usage = UsageInfo() |
| 466 | for i, content in enumerate(all_tasks): |
| 467 | if isinstance(content, str): |
| 468 | content = json.loads(content) |
| 469 | |
| 470 | if content["error_code"] != 0: |
nothing calls this directly
no test coverage detected
searching dependent graphs…