(
model,
tokenizer,
params,
device,
context_len=2048,
stream_interval=2,
judge_sent_end=False,
)
| 64 | |
| 65 | @torch.inference_mode() |
| 66 | def generate_stream_chatglm( |
| 67 | model, |
| 68 | tokenizer, |
| 69 | params, |
| 70 | device, |
| 71 | context_len=2048, |
| 72 | stream_interval=2, |
| 73 | judge_sent_end=False, |
| 74 | ): |
| 75 | prompt = params["prompt"] |
| 76 | temperature = float(params.get("temperature", 1.0)) |
| 77 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) |
| 78 | top_p = float(params.get("top_p", 1.0)) |
| 79 | max_new_tokens = int(params.get("max_new_tokens", 256)) |
| 80 | echo = params.get("echo", True) |
| 81 | |
| 82 | model_type = str(type(model)).lower() |
| 83 | if "peft" in model_type: |
| 84 | model_type = str(type(model.base_model.model)).lower() |
| 85 | |
| 86 | if "chatglm3" in model_type: |
| 87 | message_list = recover_message_list(prompt) |
| 88 | inputs = tokenizer.build_chat_input( |
| 89 | query=message_list[-1]["content"], history=message_list[:-1], role="user" |
| 90 | ).to(model.device) |
| 91 | else: |
| 92 | inputs = tokenizer([prompt], return_tensors="pt").to(model.device) |
| 93 | input_echo_len = len(inputs["input_ids"][0]) |
| 94 | |
| 95 | gen_kwargs = { |
| 96 | "max_length": max_new_tokens + input_echo_len, |
| 97 | "do_sample": True if temperature > 1e-5 else False, |
| 98 | "top_p": top_p, |
| 99 | "repetition_penalty": repetition_penalty, |
| 100 | "logits_processor": [invalid_score_processor], |
| 101 | } |
| 102 | if temperature > 1e-5: |
| 103 | gen_kwargs["temperature"] = temperature |
| 104 | |
| 105 | total_len = 0 |
| 106 | for total_ids in model.stream_generate(**inputs, **gen_kwargs): |
| 107 | total_ids = total_ids.tolist()[0] |
| 108 | total_len = len(total_ids) |
| 109 | if echo: |
| 110 | output_ids = total_ids |
| 111 | else: |
| 112 | output_ids = total_ids[input_echo_len:] |
| 113 | response = tokenizer.decode(output_ids) |
| 114 | response = process_response(response) |
| 115 | |
| 116 | yield { |
| 117 | "text": response, |
| 118 | "usage": { |
| 119 | "prompt_tokens": input_echo_len, |
| 120 | "completion_tokens": total_len - input_echo_len, |
| 121 | "total_tokens": total_len, |
| 122 | }, |
| 123 | "finish_reason": None, |
nothing calls this directly
no test coverage detected
searching dependent graphs…