(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict)
| 44 | |
| 45 | @torch.inference_mode() |
| 46 | def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict): |
| 47 | messages = params["messages"] |
| 48 | tools = params["tools"] |
| 49 | temperature = float(params.get("temperature", 1.0)) |
| 50 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) |
| 51 | top_p = float(params.get("top_p", 1.0)) |
| 52 | max_new_tokens = int(params.get("max_tokens", 256)) |
| 53 | echo = params.get("echo", True) |
| 54 | messages = process_chatglm_messages(messages, tools=tools) |
| 55 | query, role = messages[-1]["content"], messages[-1]["role"] |
| 56 | |
| 57 | inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role) |
| 58 | inputs = inputs.to(model.device) |
| 59 | input_echo_len = len(inputs["input_ids"][0]) |
| 60 | |
| 61 | if input_echo_len >= model.config.seq_length: |
| 62 | print(f"Input length larger than {model.config.seq_length}") |
| 63 | |
| 64 | eos_token_id = [ |
| 65 | tokenizer.eos_token_id, |
| 66 | tokenizer.get_command("<|user|>"), |
| 67 | tokenizer.get_command("<|observation|>") |
| 68 | ] |
| 69 | |
| 70 | gen_kwargs = { |
| 71 | "max_new_tokens": max_new_tokens, |
| 72 | "do_sample": True if temperature > 1e-5 else False, |
| 73 | "top_p": top_p, |
| 74 | "repetition_penalty": repetition_penalty, |
| 75 | "logits_processor": [InvalidScoreLogitsProcessor()], |
| 76 | } |
| 77 | if temperature > 1e-5: |
| 78 | gen_kwargs["temperature"] = temperature |
| 79 | |
| 80 | total_len = 0 |
| 81 | for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs): |
| 82 | total_ids = total_ids.tolist()[0] |
| 83 | total_len = len(total_ids) |
| 84 | if echo: |
| 85 | output_ids = total_ids[:-1] |
| 86 | else: |
| 87 | output_ids = total_ids[input_echo_len:-1] |
| 88 | |
| 89 | response = tokenizer.decode(output_ids) |
| 90 | if response and response[-1] != "�": |
| 91 | response, stop_found = apply_stopping_strings(response, ["<|observation|>"]) |
| 92 | |
| 93 | yield { |
| 94 | "text": response, |
| 95 | "usage": { |
| 96 | "prompt_tokens": input_echo_len, |
| 97 | "completion_tokens": total_len - input_echo_len, |
| 98 | "total_tokens": total_len, |
| 99 | }, |
| 100 | "finish_reason": "function_call" if stop_found else None, |
| 101 | } |
| 102 | |
| 103 | if stop_found: |
no test coverage detected