MCP · copy · Index your code
hub / github.com/FoundationVision/ByteTrack / step

Method step

tutorials/transtrack/tracker.py:54–191  ·  view source on GitHub ↗
(self, output_results)

Source from the content-addressed store, hash-verified

52
53
54 def step(self, output_results):
55 scores = output_results["scores"]
56 bboxes = output_results["boxes"] # x1y1x2y2
57 track_bboxes = output_results["track_boxes"] if "track_boxes" in output_results else None # x1y1x2y2
58
59 results = list()
60 results_dict = dict()
61 results_second = list()
62
63 tracks = list()
64
65 for idx in range(scores.shape[0]):
66 if idx in self.tracks_dict and track_bboxes is not None:
67 self.tracks_dict[idx]["bbox"] = track_bboxes[idx, :].cpu().numpy().tolist()
68
69 if scores[idx] >= self.score_thresh:
70 obj = dict()
71 obj["score"] = float(scores[idx])
72 obj["bbox"] = bboxes[idx, :].cpu().numpy().tolist()
73 results.append(obj)
74 results_dict[idx] = obj
75 elif scores[idx] >= self.low_thresh:
76 second_obj = dict()
77 second_obj["score"] = float(scores[idx])
78 second_obj["bbox"] = bboxes[idx, :].cpu().numpy().tolist()
79 results_second.append(second_obj)
80 results_dict[idx] = second_obj
81
82 tracks = [v for v in self.tracks_dict.values()] + self.unmatched_tracks
83 # for trackss in tracks:
84 # print(trackss.keys())
85 N = len(results)
86 M = len(tracks)
87
88 ret = list()
89 unmatched_tracks = [t for t in range(M)]
90 unmatched_dets = [d for d in range(N)]
91
92 if N > 0 and M > 0:
93 det_box = torch.stack([torch.tensor(obj['bbox']) for obj in results], dim=0) # N x 4
94 track_box = torch.stack([torch.tensor(obj['bbox']) for obj in tracks], dim=0) # M x 4
95 cost_bbox = 1.0 - box_ops.generalized_box_iou(det_box, track_box) # N x M
96
97 matched_indices = linear_sum_assignment(cost_bbox)
98 unmatched_dets = [d for d in range(N) if not (d in matched_indices[0])]
99 unmatched_tracks = [d for d in range(M) if not (d in matched_indices[1])]
100
101 matches = [[],[]]
102 for (m0, m1) in zip(matched_indices[0], matched_indices[1]):
103 if cost_bbox[m0, m1] > 1.2:
104 unmatched_dets.append(m0)
105 unmatched_tracks.append(m1)
106 else:
107 matches[0].append(m0)
108 matches[1].append(m1)
109
110 for (m0, m1) in zip(matches[0], matches[1]):
111 track = results[m0]

Callers 3

train_one_epoch (Function) · 0.45
evaluate (Function) · 0.45
main (Function) · 0.45

Calls

no outgoing calls

Tested by

no test coverage detected