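Update the confusion matrix from one image's detections and labels, matching boxes by intersection-over-union (Jaccard index). Both sets of boxes are expected to be in (x1, y1, x2, y2) format. Arguments: detections (Array[N, 6]): x1, y1, x2, y2, conf, class; labels (Array[M, 5]): class, x1, y1, x2, y2. Returns: None; the confusion matrix is updated in place.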
(self, detections, labels)
| 132 | self.iou_thres = iou_thres |
| 133 | |
def process_batch(self, detections, labels):
    """
    Update the confusion matrix with the detections and ground-truth labels of one image.

    Detections whose confidence is not above ``self.conf`` are discarded. The rest are
    paired with labels by intersection-over-union (Jaccard index), both sets of boxes
    in (x1, y1, x2, y2) format. A pair is a candidate match when its IoU exceeds
    ``self.iou_thres``. Candidates are reduced greedily by IoU so that each detection
    and each label takes part in at most one match.

    ``self.matrix`` is indexed as [predicted class, true class]. Index ``self.nc``
    stands for background:
      - a matched label increments [detection class, label class];
      - an unmatched label increments [nc, label class] (missed, background FN);
      - an unmatched detection increments [detection class, nc] (background FP).

    Arguments:
        detections (Array[N, 6]), x1, y1, x2, y2, conf, class -- or None when there
            are no predictions, in which case every label is counted as missed.
        labels (Array[M, 5]), class, x1, y1, x2, y2. When ``detections`` is None,
            a 1-D tensor of class indices is accepted as well.
    Returns:
        None, updates ``self.matrix`` in place
    """
    if detections is None:
        # Callers may pass just the class column here; full (M, 5) label rows work too.
        gt_classes = (labels[:, 0] if labels.ndim > 1 else labels).int()
        for gc in gt_classes:
            self.matrix[self.nc, gc] += 1  # background FN
        return

    detections = detections[detections[:, 4] > self.conf]
    gt_classes = labels[:, 0].int()
    detection_classes = detections[:, 5].int()

    # Each row of `matches` is (label index, detection index, IoU).
    matches = np.zeros((0, 3))
    if len(gt_classes) and len(detection_classes):  # with either side empty nothing can match
        iou = box_iou(labels[:, 1:], detections[:, :4])
        x = torch.where(iou > self.iou_thres)
        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()
            if x[0].shape[0] > 1:
                # Keep the highest-IoU pair per detection, then the highest-IoU pair per label.
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[np.unique(matches[:, 0], return_index=True)[1]]

    m0, m1, _ = matches.transpose().astype(int)
    for i, gc in enumerate(gt_classes):
        j = m0 == i
        if j.sum() == 1:
            self.matrix[detection_classes[m1[j]], gc] += 1  # correct
        else:
            self.matrix[self.nc, gc] += 1  # background FN (label was not detected)

    # Every unmatched detection is a false positive. This must also run when nothing
    # matched at all (e.g. an image without labels); otherwise those FPs are lost.
    for i, dc in enumerate(detection_classes):
        if not (m1 == i).any():
            self.matrix[dc, self.nc] += 1  # background FP (predicted background)
def matrix(self):
    """
    Return the confusion-matrix counts stored on ``self.matrix``.

    NOTE(review): elsewhere ``self.matrix`` is indexed and incremented as an array,
    so an instance attribute with the same name must shadow this method. On an
    instance, ``obj.matrix`` therefore yields the array itself, and this method is
    only reachable through the class (``Cls.matrix(obj)``). Confirm whether it is
    still needed.
    """
    counts = self.matrix
    return counts