(self, message, dataset=None)
| 136 | return answer |
| 137 | |
def generate_inner(self, message, dataset=None):
    """Run one image+text generation request and post-process the reply.

    Args:
        message: Input message; ``self.prepare_inputs`` converts it into a
            text query and a list of image file paths.
        dataset: Not referenced by this method; accepted for interface
            compatibility.

    Returns:
        str: The decoded model output with surrounding whitespace stripped.
        When the query ends with one of the multiple-choice instructions
        (English or Chinese), the output is further reduced to a single
        option letter whenever one can be identified.
    """
    from types import SimpleNamespace

    query, image_paths = self.prepare_inputs(message)

    # Convert each image to RGB and close the source file right away rather
    # than leaving the handle open until garbage collection.
    images_list = []
    for image_path in image_paths:
        with Image.open(image_path) as image:
            images_list.append(image.convert('RGB'))

    # Options object handed to `process_images`. A plain namespace replaces
    # the former (deprecated) use of `abc.abstractproperty()` as an
    # attribute bag; only `image_aspect_ratio` was ever set on it.
    args = SimpleNamespace(image_aspect_ratio='pad')
    image_tensors = self.process_images(images_list, self.image_processor, args).cuda()
    prompt, input_ids = self.conversation_formatter.format_query(query)
    # Add a leading dimension of size 1 so the prompt is a batch of one.
    input_ids = input_ids.unsqueeze(0).cuda()

    with torch.inference_mode():
        kwargs = dict(images=image_tensors)
        # Instance-level generation options extend/override the defaults.
        kwargs.update(self.kwargs)
        output_ids = self.model.generate(input_ids, **kwargs)

    # The generated sequence is expected to start with the prompt tokens;
    # warn if any of them differ.
    input_token_len = input_ids.shape[1]
    n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
    # Decode only the tokens that follow the prompt.
    response = self.tokenizer.batch_decode(
        output_ids[:, input_token_len:], skip_special_tokens=True
    )[0].strip(string.whitespace)
    answer = response

    mcq_suffixes = (
        "Answer with the option's letter from the given choices directly.",
        '请直接回答选项字母。',
    )
    if query.endswith(mcq_suffixes):
        qtype = 'multiple-choice'
        separators = string.whitespace + string.punctuation
        # Repeatedly peel the reply down to a lone option letter.
        while True:
            answer = answer.strip(separators)
            if len(answer) <= 1:
                break
            # Leading letter followed by a separator, e.g. "B. because ...".
            if answer[0] in string.ascii_uppercase and answer[1] in separators:
                answer = answer[0]
                break
            # Trailing letter preceded by a separator, e.g. "... is: C".
            if answer[-1] in string.ascii_uppercase and answer[-2] in separators:
                answer = answer[-1]
                break
            if not listinstr(['answer is', 'answer:'], answer.lower()):
                break
            # NOTE(review): presumably `process_answer_prefix` drops the text
            # up to and including the matched prefix; the loop only
            # terminates if it strictly shortens `answer` here — confirm.
            answer = self.process_answer_prefix(answer, ['answer is', 'answer:'])
            answer = self.process_answer_prefix(answer, ['option'])
    else:
        qtype = 'open'

    # Every 50th call, dump one sample on the LOCAL_RANK 0 process.
    if self.count % 50 == 0 and int(os.environ.get('LOCAL_RANK', '0')) == 0:
        print(f'\n{self.BEGIN_LINE}')
        print(f'image_paths: {image_paths}\n')
        print(f'prompt: {prompt}\n')
        print(f'qtype: {qtype}\n')
        print(f'output: {response}\n')
        print(f'answer: {answer}\n')
        print(f'{self.END_LINE}\n', flush=True)

    self.count += 1

    return answer
nothing calls this directly
no test coverage detected