(request: Request, httpserver_manager: HttpServerManager)
| 58 | |
| 59 | |
| 60 | async def tgi_generate_impl(request: Request, httpserver_manager: HttpServerManager) -> Response: |
| 61 | |
| 62 | request_dict = await request.json() |
| 63 | prompt = request_dict.pop("inputs") |
| 64 | num_beam = request_dict.get("num_beam", 1) |
| 65 | sample_params_dict = format_tgi_params(request_dict["parameters"], num_beam) |
| 66 | return_details = sample_params_dict.pop("return_details", False) |
| 67 | sampling_params = SamplingParams() |
| 68 | sampling_params.init(tokenizer=httpserver_manager.tokenizer, **sample_params_dict) |
| 69 | sampling_params.verify() |
| 70 | multimodal_params_dict = request_dict.get("multimodal_params", {}) |
| 71 | multimodal_params = MultimodalParams(**multimodal_params_dict) |
| 72 | |
| 73 | results_generator = httpserver_manager.generate(prompt, sampling_params, multimodal_params, request=request) |
| 74 | |
| 75 | # Non-streaming case |
| 76 | final_output_dict = collections.defaultdict(list) |
| 77 | count_output_tokens_dict = collections.defaultdict(lambda: 0) |
| 78 | tokens_dict = collections.defaultdict(list) |
| 79 | finish_status_dict = {} |
| 80 | prompt_logprobs = None |
| 81 | prompt_token_ids = None |
| 82 | is_first_metadata = True |
| 83 | best_score = -float("inf") |
| 84 | best_sub_id = 0 |
| 85 | async for sub_req_id, request_output, metadata, finish_status in results_generator: |
| 86 | # When "--return_all_prompt_logprobs" is set, the first token's metadata will contain |
| 87 | # prompt_logprobs and prompt_token_ids. |
| 88 | if is_first_metadata: |
| 89 | prompt_logprobs = metadata.get("prompt_logprobs", None) |
| 90 | prompt_token_ids = metadata.get("prompt_token_ids", None) |
| 91 | if prompt_logprobs is not None: |
| 92 | del metadata["prompt_logprobs"] |
| 93 | if prompt_token_ids is not None: |
| 94 | del metadata["prompt_token_ids"] |
| 95 | is_first_metadata = False |
| 96 | |
| 97 | count_output_tokens_dict[sub_req_id] += 1 |
| 98 | final_output_dict[sub_req_id].append(request_output) |
| 99 | if return_details: |
| 100 | metadata["text"] = request_output |
| 101 | tokens_dict[sub_req_id].append(metadata) |
| 102 | if finish_status.is_finished(): |
| 103 | finish_status_dict[sub_req_id] = finish_status |
| 104 | if metadata["cumlogprob"] > best_score: |
| 105 | best_score = metadata["cumlogprob"] |
| 106 | best_sub_id = sub_req_id |
| 107 | |
| 108 | ret = None |
| 109 | beam_sequences = [] |
| 110 | for sub_id in list(final_output_dict.keys()): |
| 111 | if return_details: |
| 112 | beam_ret = { |
| 113 | "generated_text": "".join(final_output_dict[sub_id]), |
| 114 | "finish_reason": finish_status_dict[sub_id].get_finish_reason(), |
| 115 | "generated_tokens": count_output_tokens_dict[sub_id], |
| 116 | "logprob": tokens_dict[sub_id][-1]["cumlogprob"], |
| 117 | } |
nothing calls this directly
no test coverage detected