()
| 128 | |
| 129 | |
| 130 | def main(): |
| 131 | args = parse_args() |
| 132 | |
| 133 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): |
| 134 | raise ValueError('The output file must be a pkl file.') |
| 135 | |
| 136 | cfg = mmcv.Config.fromfile(args.config) |
| 137 | # set cudnn_benchmark |
| 138 | if cfg.get('cudnn_benchmark', False): |
| 139 | torch.backends.cudnn.benchmark = True |
| 140 | cfg.model.pretrained = None |
| 141 | cfg.data.test.test_mode = True |
| 142 | |
| 143 | # init distributed env first, since logger depends on the dist info. |
| 144 | if args.launcher == 'none': |
| 145 | distributed = False |
| 146 | else: |
| 147 | distributed = True |
| 148 | init_dist(args.launcher, **cfg.dist_params) |
| 149 | |
| 150 | # build the dataloader |
| 151 | # TODO: support multiple images per gpu (only minor changes are needed) |
| 152 | dataset = build_dataset(cfg.data.test) |
| 153 | data_loader = build_dataloader( |
| 154 | dataset, |
| 155 | imgs_per_gpu=1, |
| 156 | workers_per_gpu=cfg.data.workers_per_gpu, |
| 157 | dist=distributed, |
| 158 | shuffle=False) |
| 159 | |
| 160 | # build the model and load checkpoint |
| 161 | model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) |
| 162 | fp16_cfg = cfg.get('fp16', None) |
| 163 | if fp16_cfg is not None: |
| 164 | wrap_fp16_model(model) |
| 165 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') |
| 166 | # old versions did not save class info in checkpoint, this walkaround is |
| 167 | # for backward compatibility |
| 168 | if 'CLASSES' in checkpoint['meta']: |
| 169 | model.CLASSES = checkpoint['meta']['CLASSES'] |
| 170 | else: |
| 171 | model.CLASSES = dataset.CLASSES |
| 172 | |
| 173 | if not distributed: |
| 174 | model = MMDataParallel(model, device_ids=[0]) |
| 175 | outputs = single_gpu_test(model, data_loader, args.show) |
| 176 | else: |
| 177 | model = MMDistributedDataParallel(model.cuda()) |
| 178 | outputs = multi_gpu_test(model, data_loader, args.tmpdir) |
| 179 | |
| 180 | rank, _ = get_dist_info() |
| 181 | if args.out and rank == 0: |
| 182 | print('\nwriting results to {}'.format(args.out)) |
| 183 | mmcv.dump(outputs, args.out) |
| 184 | eval_types = args.eval |
| 185 | if eval_types: |
| 186 | print('Starting evaluate {}'.format(' and '.join(eval_types))) |
| 187 | if eval_types == ['proposal_fast']: |
no test coverage detected