MCPcopy
hub / github.com/THUDM/LongWriter / predict

Function predict

trans_web_demo.py:64–104  ·  view source on GitHub ↗
(history, prompt, max_length, top_p, temperature)

Source from the content-addressed store, hash-verified

62
63
def predict(history, prompt, max_length, top_p, temperature):
    """Stream the model's reply to the pending user turn into ``history``.

    Args:
        history: Chat transcript as a list of mutable ``[user_msg, model_msg]``
            pairs. The last pair counts as the pending turn when its
            ``model_msg`` is empty; its reply slot is filled in place.
        prompt: Optional system prompt. When truthy it is sent as a
            ``system`` message and the first pair of ``history`` is left out
            of the model context (presumably that pair is the UI's
            prompt-confirmation entry -- confirm against the caller).
        max_length: Forwarded to ``model.generate`` as ``max_new_tokens``.
        top_p: Forwarded to ``model.generate`` as ``top_p``.
        temperature: Forwarded to ``model.generate`` as ``temperature``.

    Yields:
        The same ``history`` list after every chunk received from the
        streamer. If there is no pending user turn, ``history`` is yielded
        once, unchanged, and no generation is started.
    """
    messages = []
    if prompt:
        messages.append({"role": "system", "content": prompt})

    # Split the transcript into already-answered context (``messages``) and
    # the single unanswered user turn (``query``).
    query = None
    has_pending_turn = False
    last_idx = len(history) - 1
    for idx, (user_msg, model_msg) in enumerate(history):
        if prompt and idx == 0:
            continue
        if idx == last_idx and not model_msg:
            query = user_msg
            has_pending_turn = True
            break
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if model_msg:
            messages.append({"role": "assistant", "content": model_msg})

    if not has_pending_turn:
        # Nothing to answer (empty history, or the last turn already has a
        # reply). Previously this path crashed on an unbound ``query``.
        yield history
        return

    stop = StopOnTokens()
    model_inputs = tokenizer.build_chat_input(
        query, history=messages, role='user'
    ).input_ids.to(next(model.parameters()).device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=600, skip_prompt=True, skip_special_tokens=True
    )
    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
        tokenizer.get_command("<|observation|>"),
    ]
    generate_kwargs = {
        "input_ids": model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_length,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "stopping_criteria": StoppingCriteriaList([stop]),
        "repetition_penalty": 1,
        "eos_token_id": eos_token_id,
    }

    # Generation runs in a background thread so this generator can consume
    # the streamer and yield partial output while it is still running.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    for new_token in streamer:
        # Drop anything from a leaked '<|user|>' marker onwards.
        if new_token and '<|user|>' in new_token:
            new_token = new_token.split('<|user|>')[0]
        if new_token:
            # The pending reply slot may be None rather than "".
            history[-1][1] = (history[-1][1] or "") + new_token
        yield history
105
106
107with gr.Blocks() as demo:

Callers

nothing calls this directly

Calls 3

StopOnTokensClass · 0.85
build_chat_inputMethod · 0.80
get_commandMethod · 0.80

Tested by

no test coverage detected