(history, prompt, max_length, top_p, temperature)
| 62 | |
| 63 | |
def _split_history(history, prompt):
    """Turn chat history pairs into role-tagged messages plus the pending query.

    Returns a ``(messages, query)`` tuple. ``query`` is the user message of the
    last history entry when that entry has no reply yet; it is ``None`` when
    no entry is awaiting a reply.
    """
    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})

    query = None
    last_idx = len(history) - 1
    for idx, (user_msg, model_msg) in enumerate(history):
        if prompt and idx == 0:
            # With a system prompt set, the first pair is never sent as a turn.
            # NOTE(review): presumably this pair is the UI entry created when
            # the prompt was set -- confirm against the caller.
            continue
        if idx == last_idx and not model_msg:
            # The trailing entry without a reply is the message to answer now.
            query = user_msg
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})
    return messages, query


def predict(history, prompt, max_length, top_p, temperature):
    """Stream the model's reply to the pending message into ``history``.

    Args:
        history: List of ``[user_msg, model_msg]`` pairs. The reply slot of
            the last pair is extended in place as text arrives.
        prompt: Optional system prompt; when truthy it is sent as a
            ``system`` message and the first history pair is skipped.
        max_length: Passed to ``model.generate`` as ``max_new_tokens``.
        top_p: Passed to ``model.generate`` as ``top_p``.
        temperature: Passed to ``model.generate`` as ``temperature``.

    Yields:
        The same ``history`` object after each chunk read from the streamer.
        When no entry is awaiting a reply, ``history`` is yielded once,
        unchanged, and generation is not started.
    """
    messages, query = _split_history(history, prompt)
    if query is None:
        # Nothing to answer (empty history, or the last entry already has a
        # reply). Previously this path failed with an unbound ``query``.
        yield history
        return

    device = next(model.parameters()).device
    model_inputs = tokenizer.build_chat_input(
        query, history=messages, role='user'
    ).input_ids.to(device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=600, skip_prompt=True, skip_special_tokens=True
    )
    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
        tokenizer.get_command("<|observation|>"),
    ]
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([StopOnTokens()]),
        "repetition_penalty": 1,
        "eos_token_id": eos_token_id,
    }

    # Generation runs in a background thread; this generator consumes the
    # streamer so partial text can be yielded as it is produced.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    for new_token in streamer:
        if new_token and '<|user|>' in new_token:
            # Drop the role marker and anything after it from the chunk.
            new_token = new_token.split('<|user|>')[0]
        if new_token:
            # The reply slot may be None rather than "" before the first chunk.
            history[-1][1] = (history[-1][1] or "") + new_token
        yield history
| 105 | |
| 106 | |
| 107 | with gr.Blocks() as demo: |
nothing calls this directly
no test coverage detected