(request: Request, httpserver_manager: HttpServerManager)
| 144 | |
| 145 | |
async def tgi_generate_stream_impl(request: Request, httpserver_manager: HttpServerManager) -> Response:
    """Serve a TGI-style streaming generation request as Server-Sent Events.

    The JSON body must contain ``inputs`` (the prompt). ``parameters`` (TGI
    generation parameters) and ``multimodal_params`` are optional; a missing
    or ``null`` value for either is treated as an empty dict.

    Each item produced by ``httpserver_manager.generate`` becomes one
    ``data:<json>`` event (terminated by a blank line) carrying the token
    text, its metadata and the finish status. The event whose ``finished``
    flag is true additionally carries the concatenated ``generated_text``
    and, when ``return_details`` was requested, a ``details`` summary.

    Args:
        request: Incoming HTTP request; its JSON body is parsed here and the
            request object is also forwarded to ``httpserver_manager.generate``.
        httpserver_manager: Supplies the tokenizer used to initialise the
            sampling parameters and the async generator of results.

    Returns:
        A ``StreamingResponse`` with media type ``text/event-stream``.

    Raises:
        KeyError: If the request body has no ``inputs`` key.
        Exception: If ``best_of`` is not 1 (the streaming API only supports
            ``best_of == 1``), or whatever ``SamplingParams.verify`` raises.
    """
    request_dict = await request.json()
    prompt = request_dict.pop("inputs")
    # ``parameters`` is optional in the TGI request schema; indexing it directly
    # turned an omitted (or null) field into an unhandled error instead of
    # falling back to default sampling settings.
    sample_params_dict = format_tgi_params(request_dict.get("parameters") or {})
    # Response-formatting flag; popped so it is not forwarded to SamplingParams.init.
    return_details = sample_params_dict.pop("return_details", False)
    sampling_params = SamplingParams()
    sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict)
    sampling_params.verify()
    if sampling_params.best_of != 1:
        raise Exception("stream api only support best_of == 1")
    # ``or {}`` also covers an explicit JSON null, which ``MultimodalParams(**None)`` rejects.
    multimodal_params_dict = request_dict.get("multimodal_params") or {}
    multimodal_params = MultimodalParams(**multimodal_params_dict)

    results_generator = httpserver_manager.generate(prompt, sampling_params, multimodal_params, request=request)

    # Streaming case
    async def stream_results() -> AsyncGenerator[bytes, None]:
        """Yield one UTF-8 encoded SSE ``data:`` event per generated item."""
        final_output = []
        async for _, request_output, metadata, finish_status in results_generator:
            ret = {
                "token": {
                    "id": metadata.get("id", None),
                    "text": request_output,
                    "logprob": metadata.get("logprob", None),
                    "special": metadata.get("special", False),
                    "count_output_tokens": metadata.get("count_output_tokens", 0),
                    "prompt_tokens": metadata.get("prompt_tokens", 0),
                },
                "generated_text": None,
                "finished": finish_status.is_finished(),
                "finish_reason": finish_status.get_finish_reason(),
                "details": None,
            }
            final_output.append(request_output)
            if ret["finished"]:
                # Only the terminal event carries the full text (and optional details).
                ret["generated_text"] = "".join(final_output)
                if return_details:
                    ret["details"] = {
                        # Counts streamed chunks; equals the token count only if each
                        # chunk is exactly one token. NOTE(review): confirm against
                        # metadata["count_output_tokens"].
                        "generated_tokens": len(final_output),
                        "finish_reason": finish_status.get_finish_reason(),
                        "prompt_tokens": metadata.get("prompt_tokens", 0),
                    }

            yield ("data:" + json.dumps(ret, ensure_ascii=False) + "\n\n").encode("utf-8")

    background_tasks = BackgroundTasks()
    return StreamingResponse(stream_results(), media_type="text/event-stream", background=background_tasks)
nothing calls this directly
no test coverage detected