(model_dir, retinanet, dataset_path, subset, cur_dataset)
| 45 | f.write(line) |
| 46 | |
def run_each_dataset(model_dir, retinanet, dataset_path, subset, cur_dataset):
    """Run detection + BYTETracker over one sequence and write its results file.

    Frames are read from ``<dataset_path>/<subset>/<cur_dataset>/img1``. Only
    the second half of the sorted frame list is processed: the loop starts at
    index ``len(frames) // 2``. The first iteration only feeds a frame through
    ``retinanet`` to obtain ``last_feat``; every later iteration also updates
    the tracker and records one result entry.

    Args:
        model_dir: Directory under which ``results/<cur_dataset>.txt`` is
            written.
        retinanet: Callable invoked as ``retinanet(img, last_feat=...)`` and
            expected to return ``(scores, transformed_anchors, last_feat)``.
        dataset_path: Root directory of the dataset.
        subset: Sub-directory of ``dataset_path`` holding the sequences.
        cur_dataset: Name of the sequence directory; also used as the stem of
            the output file name.

    Raises:
        IndexError: If the ``img1`` directory contains no matching image,
            because the first frame lookup indexes an empty list.
    """
    print(cur_dataset)

    img_dir = os.path.join(dataset_path, subset, cur_dataset, 'img1')
    # NOTE: this is a substring test, not a suffix test, so any file name that
    # merely contains 'jpg' or 'png' is picked up.
    img_list = sorted(
        os.path.join(img_dir, name)
        for name in os.listdir(img_dir)
        if ('jpg' in name) or ('png' in name)
    )

    img_len = len(img_list)
    # Processing starts halfway through the sequence; hoisted because the
    # value is needed both for the loop range and for the per-frame check.
    start_idx = img_len // 2

    last_feat = None
    results = []
    tracker = BYTETracker()

    for idx in range(start_idx, img_len + 1):
        print('tracking: ', idx - 1)
        with torch.no_grad():
            # The range runs one step past the last valid index; the clamp
            # makes that final step reuse the last image.
            data_path1 = img_list[min(idx, img_len - 1)]
            img_origin1 = skimage.io.imread(data_path1)
            # Three-way unpack: a 2-D (single-channel) image raises here.
            img_h, img_w, _ = img_origin1.shape

            # Zero-pad on the bottom/right so both sides are multiples of 32.
            resize_h = math.ceil(img_h / 32) * 32
            resize_w = math.ceil(img_w / 32) * 32
            img1 = np.zeros((resize_h, resize_w, 3), dtype=img_origin1.dtype)
            img1[:img_h, :img_w, :] = img_origin1

            # Scale to [0, 1], normalise with the module-level RGB statistics,
            # then convert HWC -> 1xCxHxW for the network.
            img1 = (img1.astype(np.float32) / 255.0 - np.array([[RGB_MEAN]])) / np.array([[RGB_STD]])
            img1 = torch.from_numpy(img1).permute(2, 0, 1).view(1, 3, resize_h, resize_w)
            scores, transformed_anchors, last_feat = retinanet(img1.cuda().float(), last_feat=last_feat)

            # The very first pass only produces ``last_feat``; tracking starts
            # with the following frame.
            if idx > start_idx:
                idxs = np.where(scores > 0.1)
                # run tracking
                # Only candidates scoring above 0.1 reach the tracker; the
                # first four columns are presumably the box coordinates
                # (TODO confirm against the model's output layout).
                online_targets = tracker.update(transformed_anchors[idxs[0], :4], scores[idxs[0]])
                online_tlwhs = []
                online_ids = []
                online_scores = []
                for t in online_targets:
                    online_tlwhs.append(t.tlwh)
                    online_ids.append(t.track_id)
                    online_scores.append(t.score)
                results.append((idx, online_tlwhs, online_ids, online_scores))

    fout_tracking = os.path.join(model_dir, 'results', cur_dataset + '.txt')
    write_results(fout_tracking, results)
no test coverage detected