MCPcopy Index your code
hub / github.com/FoundationVision/ByteTrack / main

Function main

tutorials/ctracker/test_byte.py:108–149  ·  view source on GitHub ↗
(args=None)

Source from the content-addressed store, hash-verified

106
107
def _strip_module_prefix(state_dict, prefix='module.'):
    """Remove a leading *prefix* from the keys of *state_dict*, in place.

    Checkpoints saved from a ``torch.nn.DataParallel``-wrapped model carry a
    ``module.`` prefix on each key, while a bare model expects un-prefixed
    keys. Only keys that start with *prefix* are renamed; every other key is
    kept as-is, so checkpoints saved without the wrapper also load. The dict
    object itself is mutated (not copied) so any attached ``_metadata`` stays
    with it.

    Args:
        state_dict: Mutable mapping of parameter names to tensors.
        prefix: Key prefix to strip.

    Returns:
        The same *state_dict* object, for convenience.
    """
    # Snapshot the keys first: the dict is modified while we walk it.
    for key in list(state_dict):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)
    return state_dict


def main(args=None):
    """Run a trained CTracker network over MOT17 ``train`` sequences.

    Parses command-line options from *args* (argparse falls back to
    ``sys.argv[1:]`` when ``None``), loads the checkpoint at ``--model_path``
    into a fresh ``model.resnet50(num_classes=1)`` network, switches it to
    eval mode, and calls ``run_each_dataset`` once per selected
    ``MOT17-XX`` sequence. A ``results`` sub-directory is created under
    ``--model_dir`` beforehand (presumably where ``run_each_dataset`` writes
    its output -- verify there).

    Args:
        args: Optional list of command-line argument strings.
    """
    parser = argparse.ArgumentParser(description='Simple script for testing a CTracker network.')
    parser.add_argument('--dataset_path', default='/dockerdata/home/jeromepeng/data/MOT/MOT17/', type=str,
                        help='Dataset path, location of the images sequence.')
    parser.add_argument('--model_dir', default='./trained_model/',
                        help='Working directory; a "results" sub-directory is created inside it.')
    parser.add_argument('--model_path', default='./trained_model/model_final.pth',
                        help='Path to the trained model checkpoint (.pth) file.')
    parser.add_argument('--seq_nums', default=0, type=int,
                        help='Single MOT17 sequence number to run; 0 runs sequences 2, 4, 5, 9, 10, 11 and 13.')

    opts = parser.parse_args(args)

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(os.path.join(opts.model_dir, 'results'), exist_ok=True)

    # NOTE(review): pretrained=True presumably fetches ImageNet backbone weights
    # that the load_state_dict call below overwrites; pretrained=False may be
    # enough -- confirm in model.resnet50 before changing it.
    retinanet = model.resnet50(num_classes=1, pretrained=True)

    # NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
    # checkpoints. On torch >= 2.6 the default weights_only=True may reject a
    # pickled nn.Module; pass weights_only=False there if loading fails.
    checkpoint = torch.load(opts.model_path)

    # The checkpoint may be a whole pickled module (the original code called
    # .state_dict() on it) or already a plain state dict; accept both.
    if hasattr(checkpoint, 'state_dict'):
        state_dict = checkpoint.state_dict()
    else:
        state_dict = checkpoint

    retinanet.load_state_dict(_strip_module_prefix(state_dict))

    # Hard-coded: the rest of the pipeline (run_each_dataset) may also assume
    # CUDA -- TODO confirm before adding a CPU code path.
    use_gpu = True
    if use_gpu:
        retinanet = retinanet.cuda()

    retinanet.eval()

    # Either the one sequence requested on the command line, or the default
    # list of MOT17 train sequences.
    seq_nums = [opts.seq_nums] if opts.seq_nums > 0 else [2, 4, 5, 9, 10, 11, 13]

    for seq_num in seq_nums:
        run_each_dataset(opts.model_dir, retinanet, opts.dataset_path, 'train', 'MOT17-{:02d}'.format(seq_num))
150
151
152# for seq_num in [1, 3, 6, 7, 8, 12, 14]:

Callers 1

test_byte.pyFile · 0.70

Calls 2

run_each_datasetFunction · 0.70
evalMethod · 0.45

Tested by

no test coverage detected