(
request: Request,
# chat_request: CompletionCreateParams, # TODO: lots of weirdness here, just using raw json
# ctx: Any = Depends(weave_context),
)
| 110 | tags=["openui/chat"], |
| 111 | ) |
| 112 | async def chat_completions( |
| 113 | request: Request, |
| 114 | # chat_request: CompletionCreateParams, # TODO: lots of weirdness here, just using raw json |
| 115 | # ctx: Any = Depends(weave_context), |
| 116 | ): |
| 117 | if request.session.get("user_id") is None: |
| 118 | raise HTTPException(status_code=401, detail="Login required to use OpenUI") |
| 119 | user_id = request.session["user_id"] |
| 120 | yesterday = datetime.now() - timedelta(days=1) |
| 121 | tokens = Usage.tokens_since(user_id, yesterday.date()) |
| 122 | if config.ENV == config.Env.PROD and tokens > config.MAX_TOKENS: |
| 123 | raise HTTPException( |
| 124 | status_code=429, |
| 125 | detail="You've exceeded our usage quota, come back tomorrow to generate more UI.", |
| 126 | ) |
| 127 | try: |
| 128 | data = await request.json() # chat_request.model_dump(exclude_unset=True) |
| 129 | input_tokens = count_tokens(data["messages"]) |
| 130 | # TODO: we always assume 4096 max tokens (random fudge factor here) |
| 131 | data["max_tokens"] = 4096 - input_tokens - 20 |
| 132 | # TODO: refactor all these blocks into one once Ollama supports vision |
| 133 | # OpenAI Models |
| 134 | if data.get("model").startswith("gpt"): |
| 135 | if data["model"] == "gpt-4" or data["model"] == "gpt-4-32k": |
| 136 | raise HTTPException(status=400, data="Model not supported") |
| 137 | response: AsyncStream[ |
| 138 | ChatCompletionChunk |
| 139 | ] = await openai.chat.completions.create( |
| 140 | **data, |
| 141 | ) |
| 142 | # gpt-4 tokens are 20x more expensive |
| 143 | multiplier = 20 if "gpt-4" in data["model"] else 1 |
| 144 | return StreamingResponse( |
| 145 | openai_stream_generator(response, input_tokens, user_id, multiplier), |
| 146 | media_type="text/event-stream", |
| 147 | ) |
| 148 | # Groq Models |
| 149 | elif data.get("model").startswith("groq/"): |
| 150 | data["model"] = data["model"].replace("groq/", "") |
| 151 | if groq is None: |
| 152 | raise HTTPException(status=500, detail="Groq API key is not set.") |
| 153 | response: AsyncStream[ |
| 154 | ChatCompletionChunk |
| 155 | ] = await groq.chat.completions.create( |
| 156 | **data, |
| 157 | ) |
| 158 | return StreamingResponse( |
| 159 | openai_stream_generator(response, input_tokens, user_id, 1), |
| 160 | media_type="text/event-stream", |
| 161 | ) |
| 162 | # Litellm Models |
| 163 | elif data.get("model").startswith("litellm/"): |
| 164 | data["model"] = data["model"].replace("litellm/", "") |
| 165 | if litellm is None: |
| 166 | raise HTTPException(status=500, detail="LiteLLM API key is not set.") |
| 167 | response: AsyncStream[ |
| 168 | ChatCompletionChunk |
| 169 | ] = await litellm.chat.completions.create( |
nothing calls this directly
no test coverage detected