(
query: str,
history: List[List[str]],
model_id: str,
stop_words: List[str],
gen_kwargs: Dict,
system: str,
)
| 465 | |
| 466 | |
| 467 | async def predict( |
| 468 | query: str, |
| 469 | history: List[List[str]], |
| 470 | model_id: str, |
| 471 | stop_words: List[str], |
| 472 | gen_kwargs: Dict, |
| 473 | system: str, |
| 474 | ): |
| 475 | global model, tokenizer |
| 476 | choice_data = ChatCompletionResponseStreamChoice( |
| 477 | index=0, delta=DeltaMessage(role='assistant'), finish_reason=None) |
| 478 | chunk = ChatCompletionResponse(model=model_id, |
| 479 | choices=[choice_data], |
| 480 | object='chat.completion.chunk') |
| 481 | yield '{}'.format(_dump_json(chunk, exclude_unset=True)) |
| 482 | |
| 483 | current_length = 0 |
| 484 | stop_words_ids = [tokenizer.encode(s) |
| 485 | for s in stop_words] if stop_words else None |
| 486 | |
| 487 | delay_token_num = max([len(x) for x in stop_words]) if stop_words_ids else 0 |
| 488 | response_generator = model.chat_stream(tokenizer, |
| 489 | query, |
| 490 | history=history, |
| 491 | stop_words_ids=stop_words_ids, |
| 492 | system=system, |
| 493 | **gen_kwargs) |
| 494 | for _new_response in response_generator: |
| 495 | if len(_new_response) <= delay_token_num: |
| 496 | continue |
| 497 | new_response = _new_response[:-delay_token_num] if delay_token_num else _new_response |
| 498 | |
| 499 | if len(new_response) == current_length: |
| 500 | continue |
| 501 | |
| 502 | new_text = new_response[current_length:] |
| 503 | current_length = len(new_response) |
| 504 | |
| 505 | choice_data = ChatCompletionResponseStreamChoice( |
| 506 | index=0, delta=DeltaMessage(content=new_text), finish_reason=None) |
| 507 | chunk = ChatCompletionResponse(model=model_id, |
| 508 | choices=[choice_data], |
| 509 | object='chat.completion.chunk') |
| 510 | yield '{}'.format(_dump_json(chunk, exclude_unset=True)) |
| 511 | |
| 512 | if current_length != len(_new_response): |
| 513 | # Determine whether to print the delay tokens |
| 514 | delayed_text = _new_response[current_length:] |
| 515 | new_text = trim_stop_words(delayed_text, stop_words) |
| 516 | if len(new_text) > 0: |
| 517 | choice_data = ChatCompletionResponseStreamChoice( |
| 518 | index=0, delta=DeltaMessage(content=new_text), finish_reason=None) |
| 519 | chunk = ChatCompletionResponse(model=model_id, |
| 520 | choices=[choice_data], |
| 521 | object='chat.completion.chunk') |
| 522 | yield '{}'.format(_dump_json(chunk, exclude_unset=True)) |
| 523 | |
| 524 | choice_data = ChatCompletionResponseStreamChoice(index=0, |
no test coverage detected