MCP · copy · Index your code
hub / github.com/FoundationVision/ByteTrack / run_each_dataset

Function run_each_dataset

tutorials/ctracker/test_byte.py:47–104  ·  view source on GitHub ↗
(model_dir, retinanet, dataset_path, subset, cur_dataset)

Source from the content-addressed store, hash-verified

45 f.write(line)
46
def _list_frame_paths(frame_dir):
    """Return the sorted paths of the ``.jpg``/``.png`` files in *frame_dir*."""
    # Match on the extension rather than a substring anywhere in the name, so
    # files such as ``notes_png.txt`` or ``000001.jpg.bak`` are not picked up.
    return sorted(
        os.path.join(frame_dir, name)
        for name in os.listdir(frame_dir)
        if name.endswith(('.jpg', '.png'))
    )


def _load_frame_tensor(path):
    """Read one frame and return it as a padded, normalized ``(1, 3, H', W')`` tensor.

    The image is copied into a zero-filled array whose height and width are
    rounded up to multiples of 32 (padding on the bottom/right). It is then
    divided by 255 and normalized with the module-level ``RGB_MEAN`` /
    ``RGB_STD``. The returned tensor is on the CPU.
    """
    img = skimage.io.imread(path)
    # Unpacking three dims keeps the original requirement of an H x W x C image.
    img_h, img_w, _ = img.shape
    pad_h = math.ceil(img_h / 32) * 32
    pad_w = math.ceil(img_w / 32) * 32
    padded = np.zeros((pad_h, pad_w, 3), dtype=img.dtype)
    padded[:img_h, :img_w, :] = img
    normed = (padded.astype(np.float32) / 255.0 - np.array([[RGB_MEAN]])) / np.array([[RGB_STD]])
    # HWC -> CHW, then add the batch dimension.
    return torch.from_numpy(normed).permute(2, 0, 1).unsqueeze(0)


def run_each_dataset(model_dir, retinanet, dataset_path, subset, cur_dataset, score_thresh=0.1):
    """Run the CTracker detector + BYTETracker on one sequence and save the tracks.

    Only the second half of the sequence is tracked. Frames are read from
    ``<dataset_path>/<subset>/<cur_dataset>/img1``, and the loop starts at
    index ``len(frames) // 2``. The first model call of that loop only primes
    ``last_feat``; every later call feeds the tracker.

    Args:
        model_dir: Directory under which ``results/<cur_dataset>.txt`` is
            written. The ``results`` directory is created if missing.
        retinanet: Detector called as ``retinanet(img, last_feat=...)`` and
            returning ``(scores, transformed_anchors, last_feat)``.
        dataset_path: Dataset root directory.
        subset: Split sub-directory under *dataset_path*.
        cur_dataset: Sequence name; also used as the output file name.
        score_thresh: Only detections with a score strictly above this value
            are passed to the tracker. Defaults to ``0.1``, the previously
            hard-coded value.

    Raises:
        FileNotFoundError: If the ``img1`` directory does not exist.
        ValueError: If the ``img1`` directory contains no ``.jpg``/``.png`` frames.
    """
    print(cur_dataset)

    frame_dir = os.path.join(dataset_path, subset, cur_dataset, 'img1')
    img_list = _list_frame_paths(frame_dir)
    img_len = len(img_list)
    if img_len == 0:
        # Previously this surfaced as an opaque IndexError from img_list[-1].
        raise ValueError('no .jpg/.png frames found in {}'.format(frame_dir))

    start = img_len // 2
    tracker = BYTETracker()
    last_feat = None
    results = []

    with torch.no_grad():
        # idx runs one past the last frame index; min(...) below re-reads the
        # final frame for that extra step.
        for idx in range(start, img_len + 1):
            print('tracking: ', idx - 1)
            frame = _load_frame_tensor(img_list[min(idx, img_len - 1)])
            scores, transformed_anchors, last_feat = retinanet(
                frame.cuda().float(), last_feat=last_feat)

            if idx == start:
                # Priming step: only last_feat is needed from this call.
                continue

            # NOTE(review): presumably columns :4 of transformed_anchors are the
            # boxes of the earlier frame of the chained pair (img_list[idx - 1]).
            # That would explain why they are stored under 1-based frame id
            # ``idx`` — confirm against the CTracker model.
            keep = np.where(scores > score_thresh)[0]
            online_targets = tracker.update(transformed_anchors[keep, :4], scores[keep])

            online_tlwhs, online_ids, online_scores = [], [], []
            for t in online_targets:
                online_tlwhs.append(t.tlwh)
                online_ids.append(t.track_id)
                online_scores.append(t.score)
            results.append((idx, online_tlwhs, online_ids, online_scores))

    out_dir = os.path.join(model_dir, 'results')
    os.makedirs(out_dir, exist_ok=True)
    write_results(os.path.join(out_dir, cur_dataset + '.txt'), results)

Callers 1

main (function) · 0.70

Calls 3

update (method) · 0.95
BYTETracker (class) · 0.90
write_results (function) · 0.70

Tested by

no test coverage detected