Args: tokenizer: the TeleChat tokenizer. question: the question the model replies to in this turn. history: the conversation history used to format the input for TeleChat. stream: whether to return the full text at the end, or to yield the text token by token during generation.
(self, tokenizer, question: str = '', history: Union[List[Dict], History] = None, stream: bool = False,
generation_config: Optional[GenerationConfig] = None, **kwargs)
| 831 | ) |
| 832 | |
| 833 | def chat(self, tokenizer, question: str = '', history: Union[List[Dict], History] = None, stream: bool = False, |
| 834 | generation_config: Optional[GenerationConfig] = None, **kwargs): |
| 835 | """ |
| 836 | Args: |
| 837 | tokenizer: the tokenizer of telechat |
| 838 | question: question which the model reply in this turn |
| 839 | history: history which will format the input for telechat |
| 840 | stream: if return the full text at last or yield the text in token |
| 841 | generation_config: configuration for generation |
| 842 | **kwargs: args which will update the generation config or pass to model forward |
| 843 | """ |
| 844 | generation_config = generation_config or self.generation_config |
| 845 | if not generation_config: |
| 846 | logger.error("generation_config is None") |
| 847 | raise ValueError("generation_config must not be None") |
| 848 | if not question: |
| 849 | logger.error("question is empty") |
| 850 | raise ValueError("question must not be empty") |
| 851 | if history is None: |
| 852 | history = [] |
| 853 | |
| 854 | # we update and check generate_config here for building inputs. |
| 855 | |
| 856 | generation_config = copy.deepcopy(generation_config) |
| 857 | user_id = generation_config.user_token_id |
| 858 | bot_id = generation_config.bot_token_id |
| 859 | model_kwargs = generation_config.update(**kwargs) |
| 860 | generation_config.validate() |
| 861 | |
| 862 | # transfer to History |
| 863 | if not isinstance(history, History): |
| 864 | history = History(tokenizer, history) |
| 865 | |
| 866 | inputs = self.build_inputs_for_chat(tokenizer, question, history, generation_config, user_id, bot_id) |
| 867 | history.append({"role": "user", "content": question}) |
| 868 | if stream: |
| 869 | streamer = TelechatIterTextStreamer(tokenizer, history,skip_prompt=True) |
| 870 | Thread(target=self.generate, kwargs=dict( |
| 871 | inputs=inputs.to(self.device), streamer=streamer, |
| 872 | generation_config=generation_config, **model_kwargs |
| 873 | )).start() |
| 874 | return streamer |
| 875 | else: |
| 876 | outputs = self.generate(inputs.to(self.device), generation_config=generation_config, **model_kwargs) |
| 877 | response = tokenizer.decode(outputs[0][len(inputs[0]):-1]) |
| 878 | history.append({"role": "bot", "content": response}) |
| 879 | return response, history |
| 880 | |
| 881 | def build_inputs_for_chat(self, tokenizer, question, history, generation_config, usr_id, bot_id): |
| 882 | """ |
nothing calls this directly
no test coverage detected