(cfg, dataset, image_set, root_path, dataset_path,
ctx, prefix, epoch,
vis, shuffle, thresh, logger=None, output_path=None)
| 24 | |
| 25 | |
def test_rpn(cfg, dataset, image_set, root_path, dataset_path,
             ctx, prefix, epoch,
             vis, shuffle, thresh, logger=None, output_path=None):
    """Generate RPN proposals for a dataset split and log their recall.

    Side effect: mutates the caller's ``cfg`` by setting
    ``cfg.TEST.HAS_RPN = True``. The config is also printed to stdout and
    written to the logger.

    Args:
        cfg: configuration object. This function writes ``cfg.TEST.HAS_RPN``
            and reads ``cfg.symbol`` and ``cfg.SCALES``. Each ``cfg.SCALES``
            entry is indexed as ``v[0]`` / ``v[1]`` to size the input buffer.
        dataset: name of the dataset class. It is resolved with ``eval`` and
            constructed as ``dataset(image_set, root_path, dataset_path,
            result_path=output_path)``.
        image_set: forwarded to the dataset constructor.
        root_path: forwarded to the dataset constructor.
        dataset_path: forwarded to the dataset constructor.
        ctx: device contexts. ``len(ctx)`` is the loader batch size, and the
            sequence is passed to ``Predictor`` as ``context``. It is presumably
            a list of MXNet contexts; confirm against callers.
        prefix: checkpoint prefix forwarded to ``load_param``.
        epoch: checkpoint epoch forwarded to ``load_param``.
        vis: forwarded to ``generate_proposals``.
        shuffle: forwarded to ``TestLoader``.
        thresh: forwarded to ``generate_proposals``.
        logger: logger for progress output. When None, the root logger is
            configured via ``logging.basicConfig()`` and set to INFO.
        output_path: forwarded to the dataset constructor as ``result_path``.

    Returns:
        None. The recall summary from ``imdb.evaluate_recall`` is only logged.
    """
    # Fall back to the root logger when the caller did not supply one.
    if logger is None:
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

    # Proposal generation needs the RPN branch enabled at test time.
    cfg.TEST.HAS_RPN = True

    # Print the effective config to stdout and record it in the log.
    pprint.pprint(cfg)
    logger.info('testing rpn cfg:{}\n'.format(pprint.pformat(cfg)))

    # Load the symbol: resolves "<module>.<class>" named by cfg.symbol.
    # NOTE(review): eval() executes arbitrary code; cfg.symbol and `dataset`
    # must come from trusted configuration only.
    sym_instance = eval(cfg.symbol + '.' + cfg.symbol)()
    sym = sym_instance.get_symbol_rpn(cfg, is_train=False)

    # Load the dataset and prepare the imdb/roidb for testing.
    imdb = eval(dataset)(image_set, root_path, dataset_path, result_path=output_path)
    roidb = imdb.gt_roidb()
    test_data = TestLoader(roidb, cfg, batch_size=len(ctx), shuffle=shuffle, has_rpn=True)

    # Load trained weights from the checkpoint.
    arg_params, aux_params = load_param(prefix, epoch)

    # Infer shapes from one device's input description.
    data_shape_dict = dict(test_data.provide_data_single)
    sym_instance.infer_shape(data_shape_dict)

    # Verify the loaded parameters match the inferred shapes.
    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict, is_train=False)

    # Size the input buffer for the largest configured scale.
    data_names = [k[0] for k in test_data.provide_data[0]]
    label_names = None if test_data.provide_label[0] is None else [k[0] for k in test_data.provide_label[0]]
    max_height = max(v[0] for v in cfg.SCALES)
    max_width = max(v[1] for v in cfg.SCALES)
    max_data_shape = [[('data', (1, 3, max_height, max_width))]]

    # Create the predictor.
    predictor = Predictor(sym, data_names, label_names,
                          context=ctx, max_data_shapes=max_data_shape,
                          provide_data=test_data.provide_data, provide_label=test_data.provide_label,
                          arg_params=arg_params, aux_params=aux_params)

    # Generate proposals, then evaluate and log their recall.
    imdb_boxes = generate_proposals(predictor, test_data, imdb, cfg, vis=vis, thresh=thresh)

    all_log_info = imdb.evaluate_recall(roidb, candidate_boxes=imdb_boxes)
    logger.info(all_log_info)
no test coverage detected