(prediction, num_classes, conf_thre=0.7, nms_thre=0.45)
| 120 | |
| 121 | |
def postprocess(prediction, num_classes, conf_thre=0.7, nms_thre=0.45):
    """Filter raw detector output by confidence and run NMS on each image.

    Args:
        prediction: Tensor indexed as ``[image, box, channel]``. Channels
            0-3 are read as box centre x, centre y, width and height,
            channel 4 as the objectness score and channels
            ``5:5 + num_classes`` as per-class scores.
            NOTE: channels 0-3 are overwritten in place with corner
            coordinates ``(x1, y1, x2, y2)``.
        num_classes: Number of class-score channels that follow channel 4.
        conf_thre: Minimum ``objectness * best class score`` for a box to
            be kept.
        nms_thre: IoU threshold handed to torchvision's NMS.

    Returns:
        A list with one entry per image. An entry is ``None`` when the image
        has no boxes or none reaches ``conf_thre``; otherwise it is a tensor
        whose rows are
        ``(x1, y1, x2, y2, obj_conf, class_conf, class_pred)``.
    """
    # Centre/size -> corner conversion. Both corner tensors are materialised
    # before anything is written back, so the in-place update never reads
    # values that were already converted.
    half_size = prediction[:, :, 2:4] / 2
    top_left = prediction[:, :, 0:2] - half_size
    bottom_right = prediction[:, :, 0:2] + half_size
    prediction[:, :, 0:2] = top_left
    prediction[:, :, 2:4] = bottom_right

    output = [None] * len(prediction)
    for i, image_pred in enumerate(prediction):
        # Nothing to score for this image => leave its slot as None.
        if not image_pred.numel():
            continue

        # Best class score and its index for every box, both shaped (N, 1).
        class_conf, class_pred = torch.max(
            image_pred[:, 5:5 + num_classes], 1, keepdim=True)

        # Select column 0 explicitly rather than calling a bare squeeze():
        # with exactly one box, squeeze() would also drop the box dimension
        # and yield a 0-dim mask, and indexing with a 0-dim bool mask adds a
        # leading axis that breaks the column slices below.
        conf_mask = image_pred[:, 4] * class_conf[:, 0] >= conf_thre

        # Detections ordered as
        # (x1, y1, x2, y2, obj_conf, class_conf, class_pred)
        detections = torch.cat(
            (image_pred[:, :5], class_conf, class_pred.float()), 1)
        detections = detections[conf_mask]
        if not detections.numel():
            continue

        boxes = detections[:, :4]
        scores = detections[:, 4] * detections[:, 5]
        if LooseVersion(torchvision.__version__) >= LooseVersion('0.8.0'):
            # Per-class NMS: boxes only suppress others of the same class.
            keep = torchvision.ops.batched_nms(
                boxes, scores, detections[:, 6], nms_thre)
        else:
            # NOTE(review): this fallback ignores the class index, so on
            # older torchvision boxes of different classes can suppress
            # each other -- confirm that this difference is intended.
            keep = torchvision.ops.nms(boxes, scores, nms_thre)

        # Each image index is visited exactly once, so a plain assignment
        # is sufficient here.
        output[i] = detections[keep]

    return output
no test coverage detected