Repository: github.com/msracver/Deformable-ConvNets

Function test_rpn

rfcn/function/test_rpn.py:26–76
(cfg, dataset, image_set, root_path, dataset_path, ctx, prefix, epoch, vis, shuffle, thresh, logger=None, output_path=None)
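In the body of `test_rpn`, `cfg.symbol` names both a module and the network class it defines, and is turned into an instance with `eval(cfg.symbol + '.' + cfg.symbol)()`; the `dataset` argument is likewise a string that is evaluated and then called to build the imdb. Both depend on names already imported into the file, presumably in lines 1–25, which this page does not show. The following is a hedged sketch of an `eval`-free alternative for the symbol lookup; the helper `resolve_by_name` and its `package_prefix` argument are hypothetical, not part of the repository.

```python
import importlib


def resolve_by_name(name, package_prefix=""):
    """Import module `package_prefix + name` and return its attribute called `name`.

    Hypothetical eval-free stand-in for eval(name + '.' + name): it assumes, as
    test_rpn does for cfg.symbol, that the module and the class share one name.
    """
    module = importlib.import_module(package_prefix + name)
    # getattr raises AttributeError if the module defines no attribute `name`.
    return getattr(module, name)


# Stdlib demonstration: the datetime module defines a class also called datetime.
cls = resolve_by_name("datetime")
print(cls.__module__, cls.__name__)  # datetime datetime
```

Applied to this repository it might read `resolve_by_name(cfg.symbol, package_prefix="symbols.")()`, assuming the symbol modules live in a package named `symbols`. The design point is that `importlib` only imports a module and reads one attribute, whereas `eval` executes whatever string the config supplies.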

```python
def test_rpn(cfg, dataset, image_set, root_path, dataset_path,
             ctx, prefix, epoch,
             vis, shuffle, thresh, logger=None, output_path=None):
    # set up logger
    if not logger:
        logging.basicConfig()
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)

    # rpn generate proposal cfg
    cfg.TEST.HAS_RPN = True

    # print cfg
    pprint.pprint(cfg)
    logger.info('testing rpn cfg:{}\n'.format(pprint.pformat(cfg)))

    # load symbol
    sym_instance = eval(cfg.symbol + '.' + cfg.symbol)()
    sym = sym_instance.get_symbol_rpn(cfg, is_train=False)

    # load dataset and prepare imdb for testing
    imdb = eval(dataset)(image_set, root_path, dataset_path, result_path=output_path)
    roidb = imdb.gt_roidb()
    test_data = TestLoader(roidb, cfg, batch_size=len(ctx), shuffle=shuffle, has_rpn=True)

    # load model
    arg_params, aux_params = load_param(prefix, epoch)

    # infer shape
    data_shape_dict = dict(test_data.provide_data_single)
    sym_instance.infer_shape(data_shape_dict)

    # check parameters
    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict, is_train=False)

    # decide maximum shape
    data_names = [k[0] for k in test_data.provide_data[0]]
    label_names = None if test_data.provide_label[0] is None else [k[0] for k in test_data.provide_label[0]]
    max_data_shape = [[('data', (1, 3, max([v[0] for v in cfg.SCALES]), max([v[1] for v in cfg.SCALES])))]]

    # create predictor
    predictor = Predictor(sym, data_names, label_names,
                          context=ctx, max_data_shapes=max_data_shape,
                          provide_data=test_data.provide_data, provide_label=test_data.provide_label,
                          arg_params=arg_params, aux_params=aux_params)

    # start testing
    imdb_boxes = generate_proposals(predictor, test_data, imdb, cfg, vis=vis, thresh=thresh)

    # report proposal recall against the ground-truth roidb
    all_log_info = imdb.evaluate_recall(roidb, candidate_boxes=imdb_boxes)
    logger.info(all_log_info)
```
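The "decide maximum shape" step only rearranges metadata, so it can be restated without MXNet. The sketch below is a pure-Python restatement of those three lines, assuming that `provide_data` and `provide_label` hold one list of `(name, shape)` pairs per device (with `None` allowed for labels), as the `[0]` and `k[0]` indexing in the original suggests. The function name `predictor_shapes` and the example shapes are hypothetical.

```python
def predictor_shapes(provide_data, provide_label, scales, channels=3):
    """Pure-Python restatement of the 'decide maximum shape' step of test_rpn.

    provide_data / provide_label: one entry per device, each a list of
    (name, shape) pairs; provide_label[0] may be None when there are no labels.
    scales: a list of pairs shaped like cfg.SCALES.
    """
    # Names are read from the first device's description only.
    data_names = [name for name, _ in provide_data[0]]
    label_names = None if provide_label[0] is None else [name for name, _ in provide_label[0]]
    # As in the original line, index 0 of each scale feeds the height and
    # index 1 the width; each maximum is taken independently. Batch size is 1.
    max_h = max(s[0] for s in scales)
    max_w = max(s[1] for s in scales)
    max_data_shape = [[("data", (1, channels, max_h, max_w))]]
    return data_names, label_names, max_data_shape


names, labels, shapes = predictor_shapes(
    [[("data", (1, 3, 600, 901)), ("im_info", (1, 3))]],  # hypothetical single-device description
    [None],
    [(600, 1000)],
)
print(names, labels, shapes)  # ['data', 'im_info'] None [[('data', (1, 3, 600, 1000))]]
```

Because only the first device's entry is inspected, the names are implicitly assumed identical across devices. And since the two maxima are computed separately, with several scales the resulting shape can pair the largest first element of one scale with the largest second element of another.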

Callers (1)

alternate_train (function) · 0.90

Calls (10)

TestLoader (class) · 0.90
load_param (function) · 0.90
Predictor (class) · 0.90
generate_proposals (function) · 0.90
info (method) · 0.80
evaluate_recall (method) · 0.80
get_symbol_rpn (method) · 0.45
gt_roidb (method) · 0.45
infer_shape (method) · 0.45
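The final step of `test_rpn` scores the generated proposals with `imdb.evaluate_recall`. That method's implementation is not shown on this page, so the following is only a simplified illustration of what a proposal-recall figure measures: the fraction of ground-truth boxes that some proposal overlaps with intersection-over-union (IoU) at or above a threshold. Boxes are assumed to be `(x1, y1, x2, y2)` in inclusive pixel coordinates (hence the `+ 1` in widths and heights); the function names are hypothetical, and unlike this sketch a real evaluator may enforce one-to-one matching between proposals and ground truth or report recall over several thresholds.

```python
def iou(a, b):
    """Intersection-over-union of two (x1, y1, x2, y2) boxes with inclusive coordinates."""
    iw = min(a[2], b[2]) - max(a[0], b[0]) + 1
    ih = min(a[3], b[3]) - max(a[1], b[1]) + 1
    if iw <= 0 or ih <= 0:
        return 0.0  # the boxes do not overlap
    inter = iw * ih
    area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1)
    area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1)
    return inter / float(area_a + area_b - inter)


def proposal_recall(gt_boxes, proposals, threshold=0.5):
    """Fraction of ground-truth boxes whose best-overlapping proposal reaches `threshold` IoU.

    gt_boxes and proposals hold one list of boxes per image, in the same image order.
    """
    covered, total = 0, 0
    for gts, props in zip(gt_boxes, proposals):
        for g in gts:
            total += 1
            best = max((iou(g, p) for p in props), default=0.0)
            if best >= threshold:
                covered += 1
    return covered / float(total) if total else 0.0


gt = [[(0, 0, 9, 9), (20, 20, 29, 29)]]
props = [[(0, 0, 9, 9), (25, 25, 34, 34)]]
print(proposal_recall(gt, props, 0.5))  # 0.5
```

In the example, the first ground-truth box is matched exactly, while the second overlaps its nearest proposal with an IoU of only 25/175, so half of the ground truth is recalled at the 0.5 threshold.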

Tested by

no test coverage detected