MCPcopy
hub / github.com/lm-sys/FastChat / generate_stream_chatglm

Function generate_stream_chatglm

fastchat/model/model_chatglm.py:66–137  ·  view source on GitHub ↗
(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
)

Source from the content-addressed store, hash-verified

64
65@torch.inference_mode()
66def generate_stream_chatglm(
67 model,
68 tokenizer,
69 params,
70 device,
71 context_len=2048,
72 stream_interval=2,
73 judge_sent_end=False,
74):
75 prompt = params["prompt"]
76 temperature = float(params.get("temperature", 1.0))
77 repetition_penalty = float(params.get("repetition_penalty", 1.0))
78 top_p = float(params.get("top_p", 1.0))
79 max_new_tokens = int(params.get("max_new_tokens", 256))
80 echo = params.get("echo", True)
81
82 model_type = str(type(model)).lower()
83 if "peft" in model_type:
84 model_type = str(type(model.base_model.model)).lower()
85
86 if "chatglm3" in model_type:
87 message_list = recover_message_list(prompt)
88 inputs = tokenizer.build_chat_input(
89 query=message_list[-1]["content"], history=message_list[:-1], role="user"
90 ).to(model.device)
91 else:
92 inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
93 input_echo_len = len(inputs["input_ids"][0])
94
95 gen_kwargs = {
96 "max_length": max_new_tokens + input_echo_len,
97 "do_sample": True if temperature > 1e-5 else False,
98 "top_p": top_p,
99 "repetition_penalty": repetition_penalty,
100 "logits_processor": [invalid_score_processor],
101 }
102 if temperature > 1e-5:
103 gen_kwargs["temperature"] = temperature
104
105 total_len = 0
106 for total_ids in model.stream_generate(**inputs, **gen_kwargs):
107 total_ids = total_ids.tolist()[0]
108 total_len = len(total_ids)
109 if echo:
110 output_ids = total_ids
111 else:
112 output_ids = total_ids[input_echo_len:]
113 response = tokenizer.decode(output_ids)
114 response = process_response(response)
115
116 yield {
117 "text": response,
118 "usage": {
119 "prompt_tokens": input_echo_len,
120 "completion_tokens": total_len - input_echo_len,
121 "total_tokens": total_len,
122 },
123 "finish_reason": None,

Callers

nothing calls this directly

Calls 3

recover_message_list — Function · 0.85
process_response — Function · 0.85
to — Method · 0.80

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…