(args=None)
| 106 | |
| 107 | |
def main(args=None):
    """Evaluate a trained CTracker network on MOT17 sequences.

    Parses command-line options, restores the network weights from
    ``--model_path`` and calls ``run_each_dataset`` for every selected
    ``MOT17-XX`` sequence of the ``train`` split.

    Args:
        args: Optional list of command-line strings handed to
            ``argparse.ArgumentParser.parse_args``; ``None`` means the
            process command line is parsed.
    """
    parser = argparse.ArgumentParser(description='Simple script for testing a CTracker network.')
    parser.add_argument('--dataset_path', default='/dockerdata/home/jeromepeng/data/MOT/MOT17/', type=str,
                        help='Dataset path, location of the images sequence.')
    parser.add_argument('--model_dir', default='./trained_model/',
                        help="Directory under which the 'results' sub-directory is created.")
    parser.add_argument('--model_path', default='./trained_model/model_final.pth',
                        help='Path to the saved model checkpoint (.pth) file.')
    parser.add_argument('--seq_nums', default=0, type=int,
                        help='Single MOT17 sequence number to evaluate; a value <= 0 (default) runs '
                             'sequences 2, 4, 5, 9, 10, 11 and 13.')

    # NOTE: `parser` is rebound to the parsed namespace; options are read as parser.<name> below.
    parser = parser.parse_args(args)

    # exist_ok avoids the check-then-create race of an os.path.exists() guard.
    os.makedirs(os.path.join(parser.model_dir, 'results'), exist_ok=True)

    # NOTE(review): pretrained=True presumably fetches backbone weights that the strict
    # load_state_dict() below overwrites anyway; consider pretrained=False once it is confirmed
    # that model.resnet50 does nothing else with this flag.
    retinanet = model.resnet50(num_classes=1, pretrained=True)

    # map_location='cpu' lets a checkpoint saved on any GPU be opened; the network is moved to
    # the GPU further down.
    # NOTE(review): torch.load unpickles arbitrary objects - only load trusted checkpoints.
    # From PyTorch 2.6 torch.load defaults to weights_only=True, which rejects a pickled
    # nn.Module; pass weights_only=False there if the checkpoint is a full model.
    checkpoint = torch.load(parser.model_path, map_location='cpu')

    # Accept either a pickled nn.Module or a plain state_dict mapping.
    state_dict = checkpoint.state_dict() if hasattr(checkpoint, 'state_dict') else checkpoint

    # Strip the 'module.' prefix (typically present when the model was saved while wrapped in
    # nn.DataParallel) so the keys match the bare network. Keys without the prefix are kept as-is,
    # so un-prefixed checkpoints are not emptied. The dict is edited in place to keep any
    # metadata attached to it by state_dict().
    prefix = 'module.'
    for key in list(state_dict):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)

    retinanet.load_state_dict(state_dict)

    use_gpu = True

    if use_gpu:
        retinanet = retinanet.cuda()

    retinanet.eval()

    seq_nums = [parser.seq_nums] if parser.seq_nums > 0 else [2, 4, 5, 9, 10, 11, 13]

    for seq_num in seq_nums:
        run_each_dataset(parser.model_dir, retinanet, parser.dataset_path, 'train', 'MOT17-{:02d}'.format(seq_num))
| 150 | |
| 151 | |
| 152 | # for seq_num in [1, 3, 6, 7, 8, 12, 14]: |
no test coverage detected