check history and build inputs here
(self, tokenizer, question, history, generation_config, usr_id, bot_id)
| 879 | return response, history |
| 880 | |
def build_inputs_for_chat(self, tokenizer, question, history, generation_config, usr_id, bot_id):
    """Build the token-id input tensor for one chat turn.

    The current question is wrapped as ``[usr_id] + question_ids + [bot_id]``
    and left-truncated so the wrapped question never exceeds the length
    budget.  Earlier messages from ``history`` are then prepended, most
    recent first, for as long as the running total stays strictly below the
    budget.

    The budget is ``self.config.seq_length - generation_config.max_new_tokens``
    when ``max_new_tokens`` is truthy, otherwise ``generation_config.max_length``
    (both clamped at 0).

    Args:
        tokenizer: Callable such that ``tokenizer(question)["input_ids"]`` is
            a list of token ids.
        question: The current user input; passed to ``tokenizer`` unchanged.
        history: Sequence of dicts with ``"role"`` and ``"input_ids"`` keys.
            The last element is treated as the most recent message.  It is
            not modified; ``None`` is treated as an empty history.
        generation_config: Object providing ``max_new_tokens``,
            ``max_length`` and ``eos_token_id``.
        usr_id: Token id placed in front of every ``"user"`` message.
        bot_id: Token id placed in front of every ``"bot"`` message and after
            the current question.

    Returns:
        A ``torch.int64`` tensor of shape ``(1, n)`` holding the input ids.

    Raises:
        ValueError: If the length budget is smaller than 3, i.e. too small
            for the two role markers plus at least one question token.
    """
    # Work out how many tokens the prompt may occupy.
    if generation_config.max_new_tokens:
        build_max_length = max(0, self.config.seq_length - generation_config.max_new_tokens)
    else:
        build_max_length = max(0, generation_config.max_length)
    if build_max_length < 3:
        message_text = "the model can not meet the requirements of input length,Please check config"
        logger.warning(message_text)
        raise ValueError(message_text)

    # Left-truncate the question.  Two slots are reserved for usr_id/bot_id,
    # so at most ``build_max_length - 2`` (>= 1) question tokens are kept.
    question_ids = tokenizer(question)["input_ids"]
    max_question_tokens = build_max_length - 2
    input_tokens = [usr_id] + question_ids[-max_question_tokens:] + [bot_id]
    length = len(input_tokens)

    # Walk the history from newest to oldest without copying or mutating it.
    for message in reversed(history or ()):
        role = message["role"]
        if role == "user":
            tokens = [usr_id] + message["input_ids"]
        elif role == "bot":
            tokens = [bot_id] + message["input_ids"] + [generation_config.eos_token_id]
        else:
            # Unknown roles contribute no tokens but do not stop the walk.
            tokens = []

        # Stop at the first message that would reach the budget; older
        # messages are dropped as well so the context stays contiguous.
        if len(tokens) + length >= build_max_length:
            break
        input_tokens = tokens + input_tokens
        # Track the running total so the budget applies to the whole prompt,
        # not just to each message individually.
        length += len(tokens)

    return torch.tensor([input_tokens], dtype=torch.int64)