(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user",
max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None,
**kwargs)
| 874 | |
| 875 | @torch.inference_mode() |
| 876 | def chat(self, tokenizer, query: str, history: List[Dict] = None, role: str = "user", |
| 877 | max_length: int = 8192, num_beams=1, do_sample=True, top_p=0.8, temperature=0.8, logits_processor=None, |
| 878 | **kwargs): |
| 879 | if history is None: |
| 880 | history = [] |
| 881 | if logits_processor is None: |
| 882 | logits_processor = LogitsProcessorList() |
| 883 | logits_processor.append(InvalidScoreLogitsProcessor()) |
| 884 | gen_kwargs = {"max_length": max_length, "num_beams": num_beams, "do_sample": do_sample, "top_p": top_p, |
| 885 | "temperature": temperature, "logits_processor": logits_processor, **kwargs} |
| 886 | inputs = tokenizer.build_chat_input(query, history=history, role=role) |
| 887 | inputs = inputs.to(self.device) |
| 888 | eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), |
| 889 | tokenizer.get_command("<|observation|>")] |
| 890 | outputs = self.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id) |
| 891 | outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):-1] |
| 892 | response = tokenizer.decode(outputs) |
| 893 | history.append({"role": role, "content": query}) |
| 894 | response = self.process_response(response) |
| 895 | return response, history |
| 896 | |
| 897 | def ppl(self, |
| 898 | input_ids: Optional[torch.Tensor] = None, |
no test coverage detected