MCPcopy
hub / github.com/zai-org/ChatGLM3 / generate_stream_chatglm3

Function generate_stream_chatglm3

openai_api_demo/utils.py:46–119  ·  view source on GitHub ↗
(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict)

Source from the content-addressed store, hash-verified

44
45@torch.inference_mode()
46def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
47 messages = params["messages"]
48 tools = params["tools"]
49 temperature = float(params.get("temperature", 1.0))
50 repetition_penalty = float(params.get("repetition_penalty", 1.0))
51 top_p = float(params.get("top_p", 1.0))
52 max_new_tokens = int(params.get("max_tokens", 256))
53 echo = params.get("echo", True)
54 messages = process_chatglm_messages(messages, tools=tools)
55 query, role = messages[-1]["content"], messages[-1]["role"]
56
57 inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)
58 inputs = inputs.to(model.device)
59 input_echo_len = len(inputs["input_ids"][0])
60
61 if input_echo_len >= model.config.seq_length:
62 print(f"Input length larger than {model.config.seq_length}")
63
64 eos_token_id = [
65 tokenizer.eos_token_id,
66 tokenizer.get_command("<|user|>"),
67 tokenizer.get_command("<|observation|>")
68 ]
69
70 gen_kwargs = {
71 "max_new_tokens": max_new_tokens,
72 "do_sample": True if temperature > 1e-5 else False,
73 "top_p": top_p,
74 "repetition_penalty": repetition_penalty,
75 "logits_processor": [InvalidScoreLogitsProcessor()],
76 }
77 if temperature > 1e-5:
78 gen_kwargs["temperature"] = temperature
79
80 total_len = 0
81 for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):
82 total_ids = total_ids.tolist()[0]
83 total_len = len(total_ids)
84 if echo:
85 output_ids = total_ids[:-1]
86 else:
87 output_ids = total_ids[input_echo_len:-1]
88
89 response = tokenizer.decode(output_ids)
90 if response and response[-1] != "�":
91 response, stop_found = apply_stopping_strings(response, ["<|observation|>"])
92
93 yield {
94 "text": response,
95 "usage": {
96 "prompt_tokens": input_echo_len,
97 "completion_tokens": total_len - input_echo_len,
98 "total_tokens": total_len,
99 },
100 "finish_reason": "function_call" if stop_found else None,
101 }
102
103 if stop_found:

Callers 3

predict (Function) · 0.90
predict_stream (Function) · 0.90
generate_chatglm3 (Function) · 0.70

Calls 3

process_chatglm_messages (Function) · 0.70
apply_stopping_strings (Function) · 0.70

Tested by

no test coverage detected