(cfg, dataset, image_set, root_path, dataset_path,
ctx, prefix, epoch,
vis, ignore_cache, shuffle, has_rpn, proposal, thresh, logger=None, output_path=None)
| 26 | |
| 27 | |
def test_rcnn(cfg, dataset, image_set, root_path, dataset_path,
              ctx, prefix, epoch,
              vis, ignore_cache, shuffle, has_rpn, proposal, thresh, logger=None, output_path=None):
    """Run detection inference over an image set and evaluate the results.

    Builds the test symbol and roidb, loads trained parameters for
    ``prefix``/``epoch``, wraps them in a ``Predictor`` and hands everything
    to ``pred_eval``.

    Args:
        cfg: Experiment config. Read here: ``cfg.symbol``, ``cfg.SCALES``
            (a sequence of 2-element pairs; the element-wise maxima become the
            last two dims of the max 'data' shape) and, when ``has_rpn`` is
            false, ``cfg.TEST.PROPOSAL_POST_NMS_TOP_N``.
        dataset: Name of a dataset class resolvable in this module; it is
            ``eval``-ed and called as
            ``cls(image_set, root_path, dataset_path, result_path=output_path)``.
        image_set, root_path, dataset_path: Forwarded to the dataset constructor.
        ctx: List of device contexts; ``len(ctx)`` is the test batch size and
            the list is passed to ``Predictor`` as ``context``.
        prefix, epoch: Checkpoint identifiers passed to ``load_param``.
        vis, ignore_cache, thresh: Forwarded to ``pred_eval``.
        shuffle: Forwarded to ``TestLoader``.
        has_rpn: If true, build the symbol with ``get_symbol`` and test on the
            ground-truth roidb; otherwise build it with ``get_symbol_rcnn`` and
            test on the roidb returned by ``imdb.<proposal>_roidb``.
        proposal: Proposal source name (used only when ``has_rpn`` is false).
        logger: Required. Receives the config dump and is passed to ``pred_eval``.
        output_path: Passed to the dataset constructor as ``result_path``.

    Raises:
        ValueError: if ``logger`` is not provided.
    """
    # An explicit raise rather than ``assert``: asserts are stripped under
    # ``python -O``, after which the call would fail later on ``logger.info``.
    if not logger:
        raise ValueError('require a logger')

    # print cfg
    pprint.pprint(cfg)
    logger.info('testing cfg:{}\n'.format(pprint.pformat(cfg)))

    # load symbol and testing data
    # NOTE(review): ``cfg.symbol`` and ``dataset`` are resolved with ``eval``
    # against this module's namespace; they must never come from untrusted input.
    sym_instance = eval(cfg.symbol + '.' + cfg.symbol)()
    if has_rpn:
        sym = sym_instance.get_symbol(cfg, is_train=False)
    else:
        sym = sym_instance.get_symbol_rcnn(cfg, is_train=False)
    imdb = eval(dataset)(image_set, root_path, dataset_path, result_path=output_path)
    gt_roidb = imdb.gt_roidb()
    if has_rpn:
        roidb = gt_roidb
    else:
        # getattr performs the same attribute lookup as the former
        # eval('imdb.' + proposal + '_roidb') without executing a code string.
        roidb = getattr(imdb, proposal + '_roidb')(gt_roidb)

    # get test data iter
    test_data = TestLoader(roidb, cfg, batch_size=len(ctx), shuffle=shuffle, has_rpn=has_rpn)

    # load model
    arg_params, aux_params = load_param(prefix, epoch, process=True)

    # infer shape
    data_shape_dict = dict(test_data.provide_data_single)
    sym_instance.infer_shape(data_shape_dict)

    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict, is_train=False)

    # decide maximum shape
    data_names = [k[0] for k in test_data.provide_data_single]
    label_names = None
    # Outer list wraps an inner list of (name, shape) pairs.
    max_data_shape = [[('data', (1, 3,
                                 max(v[0] for v in cfg.SCALES),
                                 max(v[1] for v in cfg.SCALES)))]]
    if not has_rpn:
        # 'rois' belongs in the same inner (name, shape) list as 'data'.
        # Appending it to the outer list (as before) produced the malformed
        # [[('data', ...)], ('rois', ...)] with the rois entry outside that list.
        # NOTE(review): the +30 headroom over the post-NMS proposal count is
        # kept as-is; its origin is not visible here.
        max_data_shape[0].append(('rois', (cfg.TEST.PROPOSAL_POST_NMS_TOP_N + 30, 5)))

    # create predictor
    predictor = Predictor(sym, data_names, label_names,
                          context=ctx, max_data_shapes=max_data_shape,
                          provide_data=test_data.provide_data, provide_label=test_data.provide_label,
                          arg_params=arg_params, aux_params=aux_params)

    # start detection
    pred_eval(predictor, test_data, imdb, cfg, vis=vis, ignore_cache=ignore_cache, thresh=thresh, logger=logger)
| 78 |
no test coverage detected