MCPcopy
hub / github.com/THUDM/LongWriter / chat

Method chat

train/patch/modeling_chatglm.py:876–895  ·  view source on GitHub ↗
(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
             max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
             **kwargs)

Source from the content-addressed store, hash-verified

874
875 @torch.inference_mode()
876 def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
877 max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
878 **kwargs):
879 if history is None:
880 history = []
881 if logits_processor is None:
882 logits_processor = LogitsProcessorList()
883 logits_processor.append(InvalidScoreLogitsProcessor())
884 gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p,
885 "temperature": temperature, "logits_processor": logits_processor, **kwargs}
886 inputs = tokenizer.build_chat_input(query, history=history, role=role)
887 inputs = inputs.to(self.device)
888 eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
889 tokenizer.get_command("<|observation|>")]
890 outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)
891 outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1]
892 response = tokenizer.decode(outputs)
893 history.append({"role": role, "content": query})
894 response = self.process_response(response)
895 return response, history
896
897 def ppl(self,
898 input_ids: Optional[torch.Tensor] = None,

Callers 1

get_predFunction · 0.80

Calls 4

process_responseMethod · 0.95
build_chat_inputMethod · 0.80
get_commandMethod · 0.80

Tested by

no test coverage detected