(
image_input,
box_threshold,
iou_threshold,
use_paddleocr,
imgsz,
instruction,
)
| 146 | @torch.inference_mode() |
| 147 | # @torch.autocast(device_type="cuda", dtype=torch.bfloat16) |
| 148 | def process( |
| 149 | image_input, |
| 150 | box_threshold, |
| 151 | iou_threshold, |
| 152 | use_paddleocr, |
| 153 | imgsz, |
| 154 | instruction, |
| 155 | ) -> Optional[Image.Image]: |
| 156 | |
| 157 | # image_save_path = 'imgs/saved_image_demo.png' |
| 158 | # image_input.save(image_save_path) |
| 159 | # image = Image.open(image_save_path) |
| 160 | box_overlay_ratio = image_input.size[0] / 3200 |
| 161 | draw_bbox_config = { |
| 162 | 'text_scale': 0.8 * box_overlay_ratio, |
| 163 | 'text_thickness': max(int(2 * box_overlay_ratio), 1), |
| 164 | 'text_padding': max(int(3 * box_overlay_ratio), 1), |
| 165 | 'thickness': max(int(3 * box_overlay_ratio), 1), |
| 166 | } |
| 167 | |
| 168 | ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_input, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=use_paddleocr) |
| 169 | text, ocr_bbox = ocr_bbox_rslt |
| 170 | dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_input, yolo_model, BOX_TRESHOLD = box_threshold, output_coord_in_ratio=False, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, ocr_text=text,iou_threshold=iou_threshold, imgsz=imgsz,) |
| 171 | parsed_content_list = '\n'.join([f'icon {i}: ' + str(v) for i,v in enumerate(parsed_content_list)]) |
| 172 | |
| 173 | if len(instruction) == 0: |
| 174 | print('finish processing') |
| 175 | image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img))) |
| 176 | return image, str(parsed_content_list) |
| 177 | |
| 178 | elif instruction.startswith('Q:'): |
| 179 | response = get_qa_response(instruction, image_input) |
| 180 | return image_input, response |
| 181 | |
| 182 | # parsed_content_list = str(parsed_content_list) |
| 183 | # convert xywh to yxhw |
| 184 | label_coordinates_yxhw = {} |
| 185 | for key, val in label_coordinates.items(): |
| 186 | if val[2] < 0 or val[3] < 0: |
| 187 | continue |
| 188 | label_coordinates_yxhw[key] = [val[1], val[0], val[3], val[2]] |
| 189 | image_som = plot_boxes_with_marks(image_input.copy(), [val for key, val in label_coordinates_yxhw.items()], som_generator, edgecolor=(255,0,0), fn_save=None, normalized_to_pixel=False) |
| 190 | |
| 191 | # convert xywh to xyxy |
| 192 | for key, val in label_coordinates.items(): |
| 193 | label_coordinates[key] = [val[0], val[1], val[0] + val[2], val[1] + val[3]] |
| 194 | |
| 195 | # normalize label_coordinates |
| 196 | for key, val in label_coordinates.items(): |
| 197 | label_coordinates[key] = [val[0] / image_input.size[0], val[1] / image_input.size[1], val[2] / image_input.size[0], val[3] / image_input.size[1]] |
| 198 | |
| 199 | magma_response = get_som_response(instruction, image_som) |
| 200 | print("magma repsonse: ", magma_response) |
| 201 | |
| 202 | # map magma_response into the mark id |
| 203 | mark_id = extract_mark_id(magma_response) |
| 204 | if mark_id is not None: |
| 205 | if str(mark_id) in label_coordinates: |
nothing calls this directly
no test coverage detected