| 176 | return generated_texts |
| 177 | |
def remove_overlap(boxes, iou_threshold, ocr_bbox=None):
    """Drop detected boxes that overlap a smaller detection or an OCR box.

    Boxes are indexed as ``(x1, y1, x2, y2)``. Two boxes are treated as
    overlapping when the largest of their IoU and the two
    intersection-over-own-area ratios exceeds ``iou_threshold``.

    Args:
        boxes: Object exposing ``tolist()`` (e.g. a tensor) that yields one
            4-value sequence per box.
        iou_threshold: Overlap score above which two boxes conflict.
        ocr_bbox: Optional list of OCR boxes. When non-empty, all of them
            are kept and placed first in the result, and a detected box is
            discarded if it conflicts with any of them without lying
            (more than 95 % of its own area) inside that OCR box.

    Returns:
        A ``torch.Tensor`` holding the OCR boxes (if any) followed by the
        surviving detected boxes in input order. When nothing survives, the
        result is an empty ``(0, 4)`` tensor rather than a 1-D one, so
        callers can index columns unconditionally.

    Note:
        Among overlapping detections only the strictly larger one is
        dropped, so boxes of equal area (including exact duplicates) are
        all kept.
    """
    # Kept as an assertion so a bad argument raises the same exception type
    # as before; builtin ``list`` behaves like ``typing.List`` in isinstance.
    assert ocr_bbox is None or isinstance(ocr_bbox, list), \
        "ocr_bbox must be a list or None"

    def box_area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    def intersection_area(box1, box2):
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])
        return max(0, x2 - x1) * max(0, y2 - y1)

    def overlap_score(box1, area1, box2, area2):
        # Max of IoU and both containment ratios, so a small box sitting
        # inside a big one scores ~1 even though its plain IoU is low.
        intersection = intersection_area(box1, box2)
        union = area1 + area2 - intersection + 1e-6
        if area1 > 0 and area2 > 0:
            ratio1 = intersection / area1
            ratio2 = intersection / area2
        else:
            ratio1, ratio2 = 0, 0
        return max(intersection / union, ratio1, ratio2)

    def is_inside(box1, area1, box2):
        # A degenerate box has no area to be contained; treating it as
        # "not inside" also avoids dividing by zero.
        if area1 <= 0:
            return False
        return intersection_area(box1, box2) / area1 > 0.95

    boxes = boxes.tolist()
    # Areas are needed for every pair comparison, so compute them once.
    areas = [box_area(box) for box in boxes]
    ocr_boxes = ocr_bbox or []
    ocr_areas = [box_area(box) for box in ocr_boxes]

    # OCR boxes are always kept and come first in the output.
    filtered_boxes = list(ocr_boxes)

    for i, (box1, area1) in enumerate(zip(boxes, areas)):
        # Keep the smaller box: box1 is dropped only if a strictly smaller
        # box overlaps it.
        suppressed = any(
            i != j
            and area1 > area2
            and overlap_score(box1, area1, box2, area2) > iou_threshold
            for j, (box2, area2) in enumerate(zip(boxes, areas))
        )
        if suppressed:
            continue

        # Only add the box if it does not clash with an OCR box; a box that
        # sits inside an OCR box is not considered a clash.
        clashes_with_ocr = any(
            overlap_score(box1, area1, box3, area3) > iou_threshold
            and not is_inside(box1, area1, box3)
            for box3, area3 in zip(ocr_boxes, ocr_areas)
        )
        if not clashes_with_ocr:
            filtered_boxes.append(box1)

    if not filtered_boxes:
        # torch.tensor([]) would be 1-D with shape (0,); keep the result 2-D.
        return torch.empty((0, 4))
    return torch.tensor(filtered_boxes)
| 229 | |
| 230 | |
| 231 | def remove_overlap_new(boxes, iou_threshold, ocr_bbox=None): |