| 7 | |
| 8 | |
| 9 | class Evaluator(object): |
| 10 | |
def __init__(self, data_root, seq_name, data_type):
    """Remember which sequence to evaluate and set up evaluation state.

    Args:
        data_root: Base directory; annotations are looked up beneath it
            at ``<data_root>/<seq_name>/gt/gt.txt``.
        seq_name: Name of the sequence sub-directory under ``data_root``.
        data_type: Annotation format identifier, checked and forwarded by
            ``load_annotations``.
    """
    self.data_root, self.seq_name, self.data_type = (
        data_root,
        seq_name,
        data_type,
    )

    # Annotations are read first, then a fresh accumulator is created, so
    # the instance is ready for eval_frame() as soon as construction ends.
    self.load_annotations()
    self.reset_accumulator()
| 18 | |
def load_annotations(self):
    """Load ground-truth and ignore annotations for the sequence.

    Reads ``<data_root>/<seq_name>/gt/gt.txt`` twice through
    ``read_results`` and stores the results on the instance:

    * ``gt_frame_dict`` -- loaded with ``is_gt=True``.
    * ``gt_ignore_frame_dict`` -- loaded with ``is_ignore=True``.

    Both are later queried per frame id with ``.get(frame_id, [])``.

    Raises:
        ValueError: If ``data_type`` is not ``'mot'``. (Previously this was
            an ``assert``, which raised ``AssertionError`` and was skipped
            entirely under ``python -O``.)
    """
    # Explicit raise instead of ``assert`` so the check survives ``-O``.
    if self.data_type != 'mot':
        raise ValueError(
            "unsupported data_type {!r}; only 'mot' is supported".format(
                self.data_type
            )
        )

    gt_filename = os.path.join(self.data_root, self.seq_name, 'gt', 'gt.txt')
    # Same file, two passes: the flags select which records read_results
    # returns (presumably tracked targets vs. ignore regions -- confirm
    # against read_results).
    self.gt_frame_dict = read_results(gt_filename, self.data_type, is_gt=True)
    self.gt_ignore_frame_dict = read_results(gt_filename, self.data_type, is_ignore=True)
| 25 | |
def reset_accumulator(self):
    """Replace ``self.acc`` with a newly constructed ``mm.MOTAccumulator``.

    Any events recorded in the previous accumulator are no longer reachable
    through ``self.acc`` after this call.
    """
    # auto_id=True: eval_frame() calls acc.update() without a frame id, so
    # the accumulator is asked to number frames itself.
    fresh_accumulator = mm.MOTAccumulator(auto_id=True)
    self.acc = fresh_accumulator
| 28 | |
| 29 | def eval_frame(self, frame_id, trk_tlwhs, trk_ids, rtn_events=False): |
| 30 | # results |
| 31 | trk_tlwhs = np.copy(trk_tlwhs) |
| 32 | trk_ids = np.copy(trk_ids) |
| 33 | |
| 34 | # gts |
| 35 | gt_objs = self.gt_frame_dict.get(frame_id, []) |
| 36 | gt_tlwhs, gt_ids = unzip_objs(gt_objs)[:2] |
| 37 | |
| 38 | # ignore boxes |
| 39 | ignore_objs = self.gt_ignore_frame_dict.get(frame_id, []) |
| 40 | ignore_tlwhs = unzip_objs(ignore_objs)[0] |
| 41 | |
| 42 | |
| 43 | # remove ignored results |
| 44 | keep = np.ones(len(trk_tlwhs), dtype=bool) |
| 45 | iou_distance = mm.distances.iou_matrix(ignore_tlwhs, trk_tlwhs, max_iou=0.5) |
| 46 | if len(iou_distance) > 0: |
| 47 | match_is, match_js = mm.lap.linear_sum_assignment(iou_distance) |
| 48 | match_is, match_js = map(lambda a: np.asarray(a, dtype=int), [match_is, match_js]) |
| 49 | match_ious = iou_distance[match_is, match_js] |
| 50 | |
| 51 | match_js = np.asarray(match_js, dtype=int) |
| 52 | match_js = match_js[np.logical_not(np.isnan(match_ious))] |
| 53 | keep[match_js] = False |
| 54 | trk_tlwhs = trk_tlwhs[keep] |
| 55 | trk_ids = trk_ids[keep] |
| 56 | |
| 57 | # get distance matrix |
| 58 | iou_distance = mm.distances.iou_matrix(gt_tlwhs, trk_tlwhs, max_iou=0.5) |
| 59 | |
| 60 | # acc |
| 61 | self.acc.update(gt_ids, trk_ids, iou_distance) |
| 62 | |
| 63 | if rtn_events and iou_distance.size > 0 and hasattr(self.acc, 'last_mot_events'): |
| 64 | events = self.acc.last_mot_events # only supported by https://github.com/longcw/py-motmetrics |
| 65 | else: |
| 66 | events = None |