Repository: github.com/ZFTurbo/Weighted-Boxes-Fusion

Function weighted_boxes_fusion

ensemble_boxes/ensemble_boxes_wbf.py:90–152


(boxes_list, scores_list, labels_list, weights=None, iou_thr=0.55, skip_box_thr=0.0, conf_type='avg', allows_overflow=False)
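The signature takes three parallel nested lists with one inner list per model, and the docstring asks for box coordinates normalized to [0; 1]. Detectors usually emit pixel coordinates, so a conversion step typically comes first. The sketch below assumes pixel boxes in x1, y1, x2, y2 order and a known image size. `normalize_boxes` is a hypothetical helper, not part of the library, and clipping out-of-frame values to [0, 1] is our own assumption.

```python
def normalize_boxes(boxes, width, height):
    """Convert pixel-space [x1, y1, x2, y2] boxes to floats in [0, 1].

    Hypothetical helper: clipping to [0, 1] is an assumption, not library behavior.
    """
    normalized = []
    for x1, y1, x2, y2 in boxes:
        box = [x1 / width, y1 / height, x2 / width, y2 / height]
        # Clip so predictions that spill past the image border stay in [0, 1]
        normalized.append([min(max(v, 0.0), 1.0) for v in box])
    return normalized


# One inner list per model: shape (models_number, model_preds, 4)
image_w, image_h = 200, 400
boxes_list = [
    normalize_boxes([[10, 20, 110, 220]], image_w, image_h),                        # model 1
    normalize_boxes([[12, 18, 108, 230], [150, 300, 210, 420]], image_w, image_h),  # model 2
]
scores_list = [[0.9], [0.8, 0.4]]  # one score per box, same nesting as boxes_list
labels_list = [[0], [0, 1]]        # one label per box, same nesting as boxes_list
print(boxes_list)
```

These three lists keep the per-model nesting, so they line up with the `boxes_list`, `scores_list` and `labels_list` arguments. If `weights` is passed, it needs one entry per model; otherwise the function prints a warning and falls back to equal weights.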


```python
import numpy as np

# prefilter_boxes, find_matching_box and get_weighted_box are defined
# elsewhere in ensemble_boxes/ensemble_boxes_wbf.py.


def weighted_boxes_fusion(boxes_list, scores_list, labels_list, weights=None, iou_thr=0.55, skip_box_thr=0.0, conf_type='avg', allows_overflow=False):
    '''
    :param boxes_list: list of box predictions from each model; each box is 4 numbers.
    It has 3 dimensions (models_number, model_preds, 4).
    Order of boxes: x1, y1, x2, y2. We expect float normalized coordinates [0; 1]
    :param scores_list: list of scores for each model
    :param labels_list: list of labels for each model
    :param weights: list of weights for each model. Default: None, which means weight == 1 for each model
    :param iou_thr: IoU value for boxes to be a match
    :param skip_box_thr: exclude boxes with score lower than this variable
    :param conf_type: how to calculate confidence in weighted boxes. 'avg': average value, 'max': maximum value
    :param allows_overflow: False if the confidence score must not exceed 1.0

    :return: boxes: boxes coordinates (Order of boxes: x1, y1, x2, y2).
    :return: scores: confidence scores
    :return: labels: boxes labels
    '''

    # Default to equal weights; also fall back to equal weights on a length mismatch
    if weights is None:
        weights = np.ones(len(boxes_list))
    if len(weights) != len(boxes_list):
        print('Warning: incorrect number of weights {}. Must be: {}. Set weights equal to 1.'.format(len(weights), len(boxes_list)))
        weights = np.ones(len(boxes_list))
    weights = np.array(weights)

    # Raise instead of calling exit(), which would terminate the caller's process
    if conf_type not in ['avg', 'max']:
        raise ValueError('Unknown conf_type: {}. Must be "avg" or "max"'.format(conf_type))

    # Drop boxes scoring below skip_box_thr and group the rest by label
    filtered_boxes = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr)
    if len(filtered_boxes) == 0:
        return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))

    overall_boxes = []
    for label in filtered_boxes:
        boxes = filtered_boxes[label]
        new_boxes = []       # clusters: the original boxes assigned to each fused box
        weighted_boxes = []  # one fused box per cluster

        # Clusterize boxes
        for j in range(0, len(boxes)):
            index, best_iou = find_matching_box(weighted_boxes, boxes[j], iou_thr)
            if index != -1:
                new_boxes[index].append(boxes[j])
                weighted_boxes[index] = get_weighted_box(new_boxes[index], conf_type)
            else:
                new_boxes.append([boxes[j].copy()])
                weighted_boxes.append(boxes[j].copy())

        # Rescale confidence (element 1) based on number of models and boxes
        for i in range(len(new_boxes)):
            if not allows_overflow:
                weighted_boxes[i][1] = weighted_boxes[i][1] * min(weights.sum(), len(new_boxes[i])) / weights.sum()
            else:
                weighted_boxes[i][1] = weighted_boxes[i][1] * len(new_boxes[i]) / weights.sum()
        overall_boxes.append(np.array(weighted_boxes))

    overall_boxes = np.concatenate(overall_boxes, axis=0)
    # Source lines 148–152, which return boxes, scores and labels, are not shown in this excerpt.
```
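The confidence rescaling loop can be checked by hand. Writing n for the number of boxes in a cluster (`len(new_boxes[i])`) and W for `weights.sum()`, the fused confidence is multiplied by min(W, n) / W when `allows_overflow=False` and by n / W otherwise. With two equally weighted models (W = 2), a box found by only one of them has its confidence halved; a cluster of three boxes keeps its confidence when overflow is disallowed but is scaled by 1.5 when it is allowed. The sketch below isolates just that arithmetic; `rescale_confidence` is a hypothetical name, not part of ensemble_boxes.

```python
def rescale_confidence(conf, n_boxes, weights, allows_overflow=False):
    """Apply the per-cluster confidence rescaling to one fused box.

    Hypothetical helper (not part of ensemble_boxes).
    conf:     fused confidence before rescaling
    n_boxes:  boxes merged into the cluster, i.e. len(new_boxes[i])
    weights:  per-model weights; their sum plays the role of weights.sum()
    """
    total_weight = sum(weights)
    if allows_overflow:
        # n / W: exceeds 1 when the cluster holds more boxes than W
        return conf * n_boxes / total_weight
    # min(W, n) / W: never exceeds 1, so conf can only shrink or stay the same
    return conf * min(total_weight, n_boxes) / total_weight


two_models = [1.0, 1.0]  # W = 2
print(rescale_confidence(0.8, 1, two_models))                        # found by one model
print(rescale_confidence(0.8, 2, two_models))                        # found by both models
print(rescale_confidence(0.8, 3, two_models))                        # three boxes, capped factor
print(rescale_confidence(0.8, 3, two_models, allows_overflow=True))  # three boxes, factor 1.5
```

The cap is what keeps scores at or below 1.0 by default: a single model that emits several overlapping boxes for one object cannot inflate the fused confidence unless `allows_overflow=True`.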

Callers (3)

example_wbf_2_models (function)
example_wbf_1_model (function)
process_single_id (function)

Calls (3)

prefilter_boxes (function)
find_matching_box (function)
get_weighted_box (function)
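The helpers that drive the clustering loop are not shown on this page. The sketch below is a simplified, self-contained stand-in for them, not the library code; `iou`, `match_box`, `fuse` and `cluster_and_fuse` are hypothetical names. It assumes that the boxes for a single label arrive as (score, [x1, y1, x2, y2]) pairs, that they are visited from highest to lowest score, and that a fused box's coordinates are the score-weighted mean of its cluster members. Only the call shapes follow the excerpt: `find_matching_box` returns an index (or -1) plus the best IoU, and `get_weighted_box` takes a cluster and `conf_type`.

```python
def iou(a, b):
    """Intersection over union of two [x1, y1, x2, y2] boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0


def match_box(fused, box, iou_thr):
    """(index, iou) of the fused box overlapping `box` most above iou_thr, else (-1, iou_thr)."""
    best_index, best_iou = -1, iou_thr
    for i, (_, candidate) in enumerate(fused):
        overlap = iou(candidate, box)
        if overlap > best_iou:
            best_index, best_iou = i, overlap
    return best_index, best_iou


def fuse(cluster, conf_type="avg"):
    """Score-weighted mean of coordinates; confidence is the mean or max score."""
    total = sum(score for score, _ in cluster)
    coords = [sum(score * box[k] for score, box in cluster) / total for k in range(4)]
    if conf_type == "avg":
        conf = total / len(cluster)
    elif conf_type == "max":
        conf = max(score for score, _ in cluster)
    else:
        raise ValueError(f"Unknown conf_type: {conf_type!r}")
    return conf, coords


def cluster_and_fuse(predictions, iou_thr=0.55, conf_type="avg"):
    """Greedy clustering of (score, box) pairs for one label."""
    clusters, fused = [], []
    # Assumption: visit boxes from highest to lowest score
    for score, box in sorted(predictions, key=lambda p: p[0], reverse=True):
        index, _ = match_box(fused, box, iou_thr)
        if index != -1:
            clusters[index].append((score, box))
            fused[index] = fuse(clusters[index], conf_type)  # refresh the fused box
        else:
            clusters.append([(score, box)])
            fused.append((score, list(box)))  # a new cluster starts as a copy of the box
    return fused, [len(c) for c in clusters]


predictions = [
    (0.9, [0.10, 0.10, 0.50, 0.50]),
    (0.6, [0.12, 0.10, 0.52, 0.50]),  # overlaps the first box (IoU about 0.9)
    (0.7, [0.60, 0.60, 0.90, 0.90]),  # no overlap, so it forms its own cluster
]
fused, counts = cluster_and_fuse(predictions)
print(fused, counts)
```

Because each new box is matched against the running fused box rather than the original members, the cluster's representative shifts as boxes join it. The returned `counts` play the role of `len(new_boxes[i])` in the confidence rescaling step.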

Tested by

no test coverage detected