Repository: github.com/alibaba/EasyCV

Function postprocess

easycv/models/detection/utils/postprocess.py:122–164
(prediction, num_classes, conf_thre=0.7, nms_thre=0.45)
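Judging from the slicing in the source, each row of `prediction` holds `(cx, cy, w, h, obj_conf)` followed by `num_classes` class scores. A box is kept only if `obj_conf * class_conf >= conf_thre`, where `class_conf` is its highest class score. The plain-Python helper below, `decode_row` (a hypothetical name, not part of EasyCV), sketches that per-row logic without PyTorch so the arithmetic is easy to check. One difference: it returns `class_pred` as an `int`, whereas the original stores it as a float column.

```python
from typing import List, Optional, Tuple

# Hypothetical helper illustrating the per-row decoding done in postprocess.
# Assumed row layout: [cx, cy, w, h, obj_conf, score_0, ..., score_{K-1}].
Detection = Tuple[float, float, float, float, float, float, int]


def decode_row(row: List[float], num_classes: int,
               conf_thre: float = 0.7) -> Optional[Detection]:
    """Return (x1, y1, x2, y2, obj_conf, class_conf, class_pred) or None."""
    cx, cy, w, h, obj_conf = row[:5]
    class_scores = row[5:5 + num_classes]
    # Best class index and its score (torch.max over dim 1 in the original)
    class_pred = max(range(num_classes), key=lambda k: class_scores[k])
    class_conf = class_scores[class_pred]
    # Drop the box if the combined score is below the threshold
    if obj_conf * class_conf < conf_thre:
        return None
    # Centre/size -> corner coordinates
    x1, y1 = cx - w / 2, cy - h / 2
    x2, y2 = cx + w / 2, cy + h / 2
    return (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
```

For example, `decode_row([50.0, 40.0, 20.0, 10.0, 0.9, 0.1, 0.95], num_classes=2)` returns `(40.0, 35.0, 60.0, 45.0, 0.9, 0.95, 1)`, because 0.9 * 0.95 = 0.855 clears the default threshold of 0.7. With `obj_conf = 0.5` and scores `[0.9, 0.1]` the combined score is 0.45, so the row is discarded.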

```python
# distutils is deprecated and was removed from the stdlib in Python 3.12
from distutils.version import LooseVersion

import torch
import torchvision


def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
    # prediction: (batch, num_boxes, 5 + num_classes); each row is
    # (cx, cy, w, h, obj_conf, class scores...).
    # Convert boxes from (cx, cy, w, h) to corners (x1, y1, x2, y2).
    # Note: this overwrites the first four columns of `prediction` in place.
    box_corner = prediction.new(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]

    output = [None for _ in range(len(prediction))]
    for i, image_pred in enumerate(prediction):

        # If no boxes remain, move on to the next image
        if not image_pred.numel():
            continue
        # Highest class score and its class index for every box
        class_conf, class_pred = torch.max(
            image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        # Keep boxes whose obj_conf * class_conf reaches the threshold.
        # squeeze(1) keeps the mask 1-D even when only one box is present
        # (a bare squeeze() would collapse it to a 0-d tensor).
        conf_mask = image_pred[:, 4] * class_conf.squeeze(1) >= conf_thre
        # Detections ordered as (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat(
            (image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.numel():
            continue

        if LooseVersion(torchvision.__version__) >= LooseVersion('0.8.0'):
            # Class-aware NMS: a box only suppresses boxes of the same class
            nms_out_index = torchvision.ops.batched_nms(
                detections[:, :4], detections[:, 4] * detections[:, 5],
                detections[:, 6], nms_thre)
        else:
            # Fallback: class-agnostic NMS across all classes
            nms_out_index = torchvision.ops.nms(
                detections[:, :4], detections[:, 4] * detections[:, 5],
                nms_thre)

        detections = detections[nms_out_index]
        if output[i] is None:
            output[i] = detections
        else:
            output[i] = torch.cat((output[i], detections))

    return output
```
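The two NMS branches are not equivalent. `torchvision.ops.batched_nms` only lets a box suppress boxes that share its category index (`detections[:, 6]` here), while `torchvision.ops.nms` ignores classes, so with torchvision older than 0.8.0 an overlapping detection of a different class can be dropped. The following is a minimal plain-Python greedy NMS written only to illustrate that difference. `iou` and `greedy_nms` are hypothetical helpers, not torchvision's implementation, and they assume corner-format boxes like those produced at the top of `postprocess`.

```python
from typing import List, Sequence, Tuple

Box = Tuple[float, float, float, float]  # (x1, y1, x2, y2)


def iou(a: Box, b: Box) -> float:
    """Intersection-over-union of two corner-format boxes."""
    iw = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    ih = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = iw * ih
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def greedy_nms(boxes: Sequence[Box], scores: Sequence[float],
               labels: Sequence[int], iou_thre: float,
               class_aware: bool = True) -> List[int]:
    """Indices of kept boxes, in decreasing score order.

    class_aware=True mimics batched_nms (only same-label boxes suppress
    each other); class_aware=False mimics plain nms.
    """
    order = sorted(range(len(boxes)), key=lambda k: scores[k], reverse=True)
    keep: List[int] = []
    for idx in order:
        # Suppressed if an already-kept box (of the same label, when
        # class_aware) overlaps it by more than iou_thre
        suppressed = any(
            (not class_aware or labels[k] == labels[idx])
            and iou(boxes[k], boxes[idx]) > iou_thre
            for k in keep)
        if not suppressed:
            keep.append(idx)
    return keep
```

With `boxes = [(0, 0, 10, 10), (1, 1, 11, 11), (20, 20, 30, 30)]`, `scores = [0.9, 0.8, 0.7]`, `labels = [0, 1, 0]` and a threshold of 0.45, the first two boxes have IoU 81/119 (about 0.68). The class-aware call keeps `[0, 1, 2]`; the class-agnostic call keeps `[0, 2]`.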

Callers (4)

single_mnn_test · function · 0.90
forward_test · method · 0.90
forward_export · method · 0.90
process_single · method · 0.90

Calls (1)

cat · method · 0.45

Tested by

no test coverage detected