(args=None)
| 297 | run_each_dataset(model_dir, retinanet, root_path, 'test', 'MOT17-{:02d}'.format(seq_num)) |
| 298 | |
def _strip_module_prefix(state_dict, prefix='module.'):
    """Rename every ``prefix``-ed key of ``state_dict`` in place and return it.

    Checkpoints saved from a ``torch.nn.DataParallel``-wrapped model typically
    carry a ``module.`` prefix on each parameter name. Keys that do not start
    with ``prefix`` are kept unchanged, so unwrapped checkpoints load too.
    The mapping is edited in place, rather than rebuilt, so that any
    ``_metadata`` attribute attached by ``state_dict()`` survives.
    """
    # Snapshot the keys: the mapping is mutated inside the loop.
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)
    return state_dict


def _load_state_dict(path):
    """Load the checkpoint at ``path`` and return its parameter mapping.

    The checkpoint may be a pickled model (anything exposing ``state_dict()``,
    which is what the script originally assumed) or a bare state dict.
    Tensors are mapped onto the CPU, so a GPU-saved checkpoint can be opened
    on any host. The caller moves the model to the GPU afterwards.
    """
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted
    # checkpoints. On torch versions where weights_only defaults to True, a
    # pickled full model may additionally need weights_only=False.
    checkpoint = torch.load(path, map_location='cpu')
    if hasattr(checkpoint, 'state_dict'):
        checkpoint = checkpoint.state_dict()
    return checkpoint


def main(args=None):
    """Run a trained CTracker network over the MOT17 sequences of one split.

    Command-line options are parsed from ``args`` (``sys.argv[1:]`` when
    None). The function then:

    1. creates ``<model_dir>/results`` if it is missing;
    2. builds a single-class ResNet-50 RetinaNet;
    3. loads the ``--model_path`` checkpoint into it (a ``module.`` prefix
       on the keys is stripped);
    4. calls ``run_each_dataset`` for every sequence of ``--split``.
    """
    # MOT17 sequence numbers per split.
    sequences = {
        'train': [2, 4, 5, 9, 10, 11, 13],
        'test': [1, 3, 6, 7, 8, 12, 14],
    }

    parser = argparse.ArgumentParser(description='Simple script for testing a CTracker network.')
    parser.add_argument('--dataset_path', default='/dockerdata/home/jeromepeng/data/MOT/MOT17/', type=str,
                        help='Dataset path, location of the images sequence.')
    parser.add_argument('--model_dir', default='./trained_model/',
                        help='Output directory; a results/ subdirectory is created inside it.')
    parser.add_argument('--model_path', default='./trained_model/model_final.pth',
                        help='Path to the trained model checkpoint (.pth) file.')
    parser.add_argument('--split', default='train', choices=sorted(sequences),
                        help='MOT17 split whose sequences are processed (default: train).')
    opts = parser.parse_args(args)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(os.path.join(opts.model_dir, 'results'), exist_ok=True)

    # NOTE(review): the strict load below overwrites every parameter, so the
    # pretrained backbone weights may be unnecessary; confirm in model.resnet50.
    retinanet = model.resnet50(num_classes=1, pretrained=True)
    retinanet.load_state_dict(_strip_module_prefix(_load_state_dict(opts.model_path)))

    # Previously hard-coded to True. run_each_dataset may still assume CUDA;
    # it is not visible here, so verify before relying on CPU-only runs.
    if torch.cuda.is_available():
        retinanet = retinanet.cuda()

    retinanet.eval()

    for seq_num in sequences[opts.split]:
        run_each_dataset(opts.model_dir, retinanet, opts.dataset_path, opts.split,
                         'MOT17-{:02d}'.format(seq_num))
| 335 |
no test coverage detected