| 185 | |
| 186 | |
| 187 | class Sort(object): |
def __init__(self, det_thresh, max_age=30, min_hits=3, iou_threshold=0.3):
    """Configure the SORT tracker and start it with no tracks.

    Args:
        det_thresh: score cut-off used by ``update``. A detection survives
            only if its fused score (column 4 times column 5 of the raw
            detector output) is strictly greater than this value.
        max_age: presumably the number of frames a track may go without a
            matched detection before it is discarded (TODO confirm against
            the remainder of ``update``).
        min_hits: hit streak a track updated in the current frame needs
            before ``update`` reports it. The requirement is waived while
            ``frame_count <= min_hits``.
        iou_threshold: forwarded unchanged to
            ``associate_detections_to_trackers``; presumably the minimum
            IoU for a detection/track pair to be matched.
    """
    # Tuning knobs, fixed for the lifetime of this tracker.
    self.det_thresh = det_thresh
    self.iou_threshold = iou_threshold
    self.min_hits = min_hits
    self.max_age = max_age

    # Per-sequence state mutated by ``update``.
    self.frame_count = 0  # bumped at the start of every ``update`` call
    self.trackers = []  # active KalmanBoxTracker objects
| 198 | |
| 199 | def update(self, output_results, img_info, img_size): |
| 200 | """ |
| 201 | Params: |
| 202 | dets - a numpy array of detections in the format [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...] |
| 203 | Requires: this method must be called once for each frame even with empty detections (use np.empty((0, 5)) for frames without detections). |
| 204 | Returns the a similar array, where the last column is the object ID. |
| 205 | NOTE: The number of objects returned may differ from the number of detections provided. |
| 206 | """ |
| 207 | self.frame_count += 1 |
| 208 | # post_process detections |
| 209 | output_results = output_results.cpu().numpy() |
| 210 | scores = output_results[:, 4] * output_results[:, 5] |
| 211 | bboxes = output_results[:, :4] # x1y1x2y2 |
| 212 | img_h, img_w = img_info[0], img_info[1] |
| 213 | scale = min(img_size[0] / float(img_h), img_size[1] / float(img_w)) |
| 214 | bboxes /= scale |
| 215 | dets = np.concatenate((bboxes, np.expand_dims(scores, axis=-1)), axis=1) |
| 216 | remain_inds = scores > self.det_thresh |
| 217 | dets = dets[remain_inds] |
| 218 | # get predicted locations from existing trackers. |
| 219 | trks = np.zeros((len(self.trackers), 5)) |
| 220 | to_del = [] |
| 221 | ret = [] |
| 222 | for t, trk in enumerate(trks): |
| 223 | pos = self.trackers[t].predict()[0] |
| 224 | trk[:] = [pos[0], pos[1], pos[2], pos[3], 0] |
| 225 | if np.any(np.isnan(pos)): |
| 226 | to_del.append(t) |
| 227 | trks = np.ma.compress_rows(np.ma.masked_invalid(trks)) |
| 228 | for t in reversed(to_del): |
| 229 | self.trackers.pop(t) |
| 230 | matched, unmatched_dets, unmatched_trks = associate_detections_to_trackers(dets, trks, self.iou_threshold) |
| 231 | |
| 232 | # update matched trackers with assigned detections |
| 233 | for m in matched: |
| 234 | self.trackers[m[1]].update(dets[m[0], :]) |
| 235 | |
| 236 | # create and initialise new trackers for unmatched detections |
| 237 | for i in unmatched_dets: |
| 238 | trk = KalmanBoxTracker(dets[i,:]) |
| 239 | self.trackers.append(trk) |
| 240 | i = len(self.trackers) |
| 241 | for trk in reversed(self.trackers): |
| 242 | d = trk.get_state()[0] |
| 243 | if (trk.time_since_update < 1) and (trk.hit_streak >= self.min_hits or self.frame_count <= self.min_hits): |
| 244 | ret.append(np.concatenate((d,[trk.id+1])).reshape(1,-1)) # +1 as MOT benchmark requires positive |