(filtered_boxes, ocr_bbox, image_source, caption_model_processor)
| 124 | |
| 125 | |
def get_parsed_content_icon_phi3v(filtered_boxes, ocr_bbox, image_source, caption_model_processor):
    """Generate a one-sentence description for each non-OCR box.

    Args:
        filtered_boxes: Sequence of boxes indexed as ``(x0, y0, x1, y1)``.
            Each coordinate is multiplied by the image width/height before
            cropping, so the values are presumably normalized to [0, 1] --
            TODO confirm against the caller.
        ocr_bbox: Optional sequence of OCR boxes. When truthy, the first
            ``len(ocr_bbox)`` entries of ``filtered_boxes`` are skipped
            (presumably the OCR boxes lead that list -- verify with caller).
        image_source: Image array indexed as ``[y, x, channel]``; ``shape[0]``
            is used as the height and ``shape[1]`` as the width.
        caption_model_processor: Mapping with the keys ``'model'`` and
            ``'processor'``.

    Returns:
        A list of stripped strings, one per captioned box, in box order.
        An empty list when there is nothing to caption.
    """
    non_ocr_boxes = filtered_boxes[len(ocr_bbox):] if ocr_bbox else filtered_boxes
    # Nothing to caption: skip the transform, the model lookup and the prompt
    # rendering altogether. ``len`` is used instead of truthiness so that
    # array-like box containers are handled as well as plain lists.
    if len(non_ocr_boxes) == 0:
        return []

    to_pil = ToPILImage()
    height, width = image_source.shape[0], image_source.shape[1]
    cropped_pil_images = []
    for coord in non_ocr_boxes:
        xmin, xmax = int(coord[0] * width), int(coord[2] * width)
        ymin, ymax = int(coord[1] * height), int(coord[3] * height)
        cropped_pil_images.append(to_pil(image_source[ymin:ymax, xmin:xmax, :]))

    model, processor = caption_model_processor['model'], caption_model_processor['processor']
    tokenizer = processor.tokenizer
    device = model.device
    messages = [{"role": "user", "content": "<|image_1|>\ndescribe the icon in one sentence"}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    batch_size = 5  # Number of samples per batch
    # Loop-invariant, so built once instead of once per batch.
    generation_args = {
        "max_new_tokens": 25,
        "temperature": 0.01,
        "do_sample": False,
    }
    tensor_keys = ('input_ids', 'attention_mask', 'pixel_values', 'image_sizes')
    generated_texts = []

    for start in range(0, len(cropped_pil_images), batch_size):
        batch = {key: [] for key in tensor_keys}
        for image in cropped_pil_images[start:start + batch_size]:
            image_input = processor.image_processor(image, return_tensors="pt")
            # NOTE(review): private processor API; it is called once per sample
            # and the results are batched by hand below.
            sample = processor._convert_images_texts_to_inputs(image_input, prompt, return_tensors="pt")
            for key in tensor_keys:
                batch[key].append(sample[key])

        # Left-pad every sequence to the longest one in the batch so the
        # per-sample tensors can be stacked; padded positions are masked out.
        max_len = max(ids.shape[1] for ids in batch['input_ids'])
        padded_ids, padded_masks = [], []
        for ids, mask in zip(batch['input_ids'], batch['attention_mask']):
            pad_len = max_len - ids.shape[1]
            pad_ids = torch.full((1, pad_len), tokenizer.pad_token_id, dtype=torch.long)
            pad_mask = torch.zeros(1, pad_len, dtype=torch.long)
            padded_ids.append(torch.cat([pad_ids, ids], dim=1))
            padded_masks.append(torch.cat([pad_mask, mask], dim=1))
        batch['input_ids'] = padded_ids
        batch['attention_mask'] = padded_masks

        inputs_cat = {key: torch.cat(tensors).to(device) for key, tensors in batch.items()}

        generate_ids = model.generate(**inputs_cat, eos_token_id=tokenizer.eos_token_id, **generation_args)
        # Drop the prompt tokens so only the newly generated part is decoded.
        generate_ids = generate_ids[:, inputs_cat['input_ids'].shape[1]:]
        response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        generated_texts.extend(res.strip('\n').strip() for res in response)

    return generated_texts
| 177 | |
| 178 | def remove_overlap(boxes, iou_threshold, ocr_bbox=None): |
| 179 | assert ocr_bbox is None or isinstance(ocr_bbox, List) |
no test coverage detected