Repository: github.com/msracver/Deformable-ConvNets

Function test_rcnn

fpn/function/test_rcnn.py:28–77
(cfg, dataset, image_set, root_path, dataset_path,
              ctx, prefix, epoch,
              vis, ignore_cache, shuffle, has_rpn, proposal, thresh, logger=None, output_path=None)
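In this signature, `dataset` and `proposal` are strings that the function turns into code by name: `dataset` is passed to `eval()` to obtain a dataset class, and when `has_rpn` is false, `proposal` selects the method `imdb.<proposal>_roidb`. The minimal sketch below shows the same dispatch with an explicit registry and `getattr`, which can only reach names you deliberately expose. `ToyImdb`, `DATASETS`, `load_roidb` and the roidb contents are hypothetical stand-ins, not the repository's classes.

```python
class ToyImdb:
    """Hypothetical imdb-like object exposing gt_roidb and <proposal>_roidb."""

    def __init__(self, image_set):
        self.image_set = image_set

    def gt_roidb(self):
        # One made-up ground-truth entry.
        return [{"image": "img0.jpg", "boxes": [[0, 0, 10, 10]]}]

    def rpn_roidb(self, gt_roidb):
        # Attach made-up proposals to each ground-truth entry.
        return [dict(entry, proposals=[[1, 1, 8, 8]]) for entry in gt_roidb]


# Hypothetical registry standing in for eval(dataset).
DATASETS = {"ToyImdb": ToyImdb}


def load_roidb(dataset, image_set, has_rpn, proposal):
    """Mirror test_rcnn's roidb selection using explicit lookups."""
    imdb = DATASETS[dataset](image_set)  # KeyError for unregistered names
    if has_rpn:
        # End-to-end path: only ground-truth entries are needed.
        return imdb, imdb.gt_roidb()
    gt_roidb = imdb.gt_roidb()
    # Equivalent of eval('imdb.' + proposal + '_roidb'), without eval.
    roidb_fn = getattr(imdb, proposal + "_roidb", None)
    if roidb_fn is None:
        raise ValueError("imdb has no method {}_roidb".format(proposal))
    return imdb, roidb_fn(gt_roidb)


_, roidb = load_roidb("ToyImdb", "val", has_rpn=False, proposal="rpn")
print(roidb[0]["proposals"])  # [[1, 1, 8, 8]]
```

A misspelled `proposal` then fails with a clear `ValueError` instead of an `AttributeError` raised from inside an `eval` string.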


# Relies on module-level imports not shown here: pprint, TestLoader,
# load_param, Predictor, pred_eval, plus the symbol and dataset modules
# that eval() resolves by name.
def test_rcnn(cfg, dataset, image_set, root_path, dataset_path,
              ctx, prefix, epoch,
              vis, ignore_cache, shuffle, has_rpn, proposal, thresh, logger=None, output_path=None):
    if logger is None:
        raise ValueError('require a logger')

    # print cfg to stdout and to the log
    pprint.pprint(cfg)
    logger.info('testing cfg:{}\n'.format(pprint.pformat(cfg)))

    # load symbol and testing data; cfg.symbol names both the module and the class
    sym_instance = eval(cfg.symbol + '.' + cfg.symbol)()
    imdb = eval(dataset)(image_set, root_path, dataset_path, result_path=output_path)
    if has_rpn:
        # end-to-end: the network generates its own proposals
        sym = sym_instance.get_symbol(cfg, is_train=False)
        roidb = imdb.gt_roidb()
    else:
        # RCNN only: proposals come from imdb.<proposal>_roidb
        sym = sym_instance.get_symbol_rcnn(cfg, is_train=False)
        gt_roidb = imdb.gt_roidb()
        roidb = getattr(imdb, proposal + '_roidb')(gt_roidb)

    # get test data iter; batch size equals the number of devices in ctx
    test_data = TestLoader(roidb, cfg, batch_size=len(ctx), shuffle=shuffle, has_rpn=has_rpn)

    # load model
    arg_params, aux_params = load_param(prefix, epoch, process=True)

    # infer shape
    data_shape_dict = dict(test_data.provide_data_single)
    sym_instance.infer_shape(data_shape_dict)

    sym_instance.check_parameter_shapes(arg_params, aux_params, data_shape_dict, is_train=False)

    # decide maximum shape
    data_names = [k[0] for k in test_data.provide_data_single]
    label_names = None
    max_data_shape = [[('data', (1, 3, max([v[0] for v in cfg.SCALES]), max([v[1] for v in cfg.SCALES])))]]
    if not has_rpn:
        # rois belong in the same (name, shape) group as data
        max_data_shape[0].append(('rois', (cfg.TEST.PROPOSAL_POST_NMS_TOP_N + 30, 5)))

    # create predictor
    predictor = Predictor(sym, data_names, label_names,
                          context=ctx, max_data_shapes=max_data_shape,
                          provide_data=test_data.provide_data, provide_label=test_data.provide_label,
                          arg_params=arg_params, aux_params=aux_params)

    # start detection
    pred_eval(predictor, test_data, imdb, cfg, vis=vis, ignore_cache=ignore_cache, thresh=thresh, logger=logger)
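The maximum-shape step builds a list holding one group of `(name, shape)` pairs. The minimal sketch below reproduces that computation as a standalone function so the nesting is visible: `data` takes the largest values from the `cfg.SCALES` pairs, and the `rois` entry (5 columns; in MXNet's ROI operators these are the batch index plus box corners) is meant to sit in the same group as `data`. The function name `max_data_shapes` and the sample scale values are assumptions for illustration; the extra 30 rows simply mirror the source, which does not state why they are added.

```python
def max_data_shapes(scales, post_nms_top_n, has_rpn):
    """Return max input shapes as a list with one [(name, shape), ...] group.

    scales: list of (first, second) size pairs, like cfg.SCALES.
    post_nms_top_n: like cfg.TEST.PROPOSAL_POST_NMS_TOP_N.
    """
    max_h = max(s[0] for s in scales)
    max_w = max(s[1] for s in scales)
    # One image in NCHW layout: batch 1, 3 colour channels.
    shapes = [[("data", (1, 3, max_h, max_w))]]
    if not has_rpn:
        # Append inside the group, not to the outer list.
        shapes[0].append(("rois", (post_nms_top_n + 30, 5)))
    return shapes


shapes = max_data_shapes([(600, 1000), (800, 1200)], post_nms_top_n=1000, has_rpn=False)
print(shapes)
# [[('data', (1, 3, 800, 1200)), ('rois', (1030, 5))]]
print(dict(shapes[0]))
# {'data': (1, 3, 800, 1200), 'rois': (1030, 5)}
```

A consumer that builds its shape table from the first group alone, as `dict(shapes[0])` does here, would not see a `rois` entry appended at the outer level.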

Callers 1

main (function) · 0.90

Calls 10

TestLoader (class) · 0.90
load_param (function) · 0.90
Predictor (class) · 0.90
pred_eval (function) · 0.90
info (method) · 0.80
get_symbol (method) · 0.45
gt_roidb (method) · 0.45
get_symbol_rcnn (method) · 0.45
infer_shape (method) · 0.45
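The Calls list above can be read as a fixed sequence: build the symbol and roidb, wrap them in a `TestLoader`, load parameters, check shapes, construct a `Predictor`, and hand everything to `pred_eval`. The sketch below restates that order with every dependency passed in, which is one possible way to make the flow testable without a GPU or a dataset. `run_test_pipeline`, `steps` and the stubs are hypothetical; they only record call order and do not reproduce the real classes.

```python
from types import SimpleNamespace


def run_test_pipeline(steps, cfg, has_rpn, ctx):
    """Sketch of test_rcnn's control flow with each dependency injected.

    `steps` is a hypothetical namespace of callables standing in for the
    symbol/roidb setup, TestLoader, load_param, the shape checks,
    Predictor and pred_eval.
    """
    sym, roidb = steps.build_symbol_and_roidb(cfg, has_rpn)
    test_data = steps.make_loader(roidb, batch_size=len(ctx), has_rpn=has_rpn)
    arg_params, aux_params = steps.load_params()
    steps.check_shapes(arg_params, aux_params, test_data)
    predictor = steps.make_predictor(sym, test_data, arg_params, aux_params)
    return steps.evaluate(predictor, test_data)


# Usage: stubs that record the order in which the pipeline calls them.
calls = []


def record(name, result=None):
    """Return a stub that logs `name` and returns a canned result."""
    def stub(*args, **kwargs):
        calls.append(name)
        return result
    return stub


steps = SimpleNamespace(
    build_symbol_and_roidb=record("symbol+roidb", ("sym", ["roi"])),
    make_loader=record("TestLoader", "loader"),
    load_params=record("load_param", ({}, {})),
    check_shapes=record("check_parameter_shapes"),
    make_predictor=record("Predictor", "predictor"),
    evaluate=record("pred_eval", "eval-result"),
)
run_test_pipeline(steps, cfg=None, has_rpn=True, ctx=["gpu0"])
print(calls)
# ['symbol+roidb', 'TestLoader', 'load_param', 'check_parameter_shapes', 'Predictor', 'pred_eval']
```

Swapping one stub for a capturing function also lets a test check arguments, for example that the loader's batch size follows the number of devices.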

Tested by

no test coverage detected