MCPcopy
hub / github.com/alibaba/EasyCV / run_test

Method run_test

tests/test_tools/test_export.py:52–108  ·  view source on GitHub ↗
(self,
                 CONFIG_FILE,
                 MODEL_TYPE,
                 img_size: int = 224,
                 **override_configs)

Source from the content-addressed store, hash-verified

50 super().tearDown()
51
def run_test(self,
             CONFIG_FILE,
             MODEL_TYPE,
             img_size: int = 224,
             *,
             rtol: float = 1e-05,
             atol: float = 1e-08,
             **override_configs):
    """Export a model to ONNX and check it agrees with the PyTorch model.

    The export runs in a subprocess using the command produced by
    ``build_cmd``. The same config is then rebuilt in-process, and both the
    PyTorch model and the exported ONNX model are fed one random
    ``(1, 3, img_size, img_size)`` input. Their outputs are compared.

    Args:
        CONFIG_FILE: Path of the model config file to export.
        MODEL_TYPE: Model type forwarded to ``build_cmd``.
        img_size: Height/width of the random square test input.
        rtol: Relative tolerance passed to ``torch.allclose``. The default
            equals torch's own default, so existing callers are unaffected.
        atol: Absolute tolerance passed to ``torch.allclose`` (torch's
            default by default).
        **override_configs: Entries that override ``BASIC_EXPORT_CONFIGS``.
            If ``checkpoint`` is given, it is also loaded (non-strictly) into
            the in-process PyTorch model.

    Raises:
        AssertionError: Raised via ``self.assertTrue`` if the output shapes
            or values of the two models differ.
    """
    configs = BASIC_EXPORT_CONFIGS.copy()
    configs['config_file'] = CONFIG_FILE
    configs.update(override_configs)

    cmd = build_cmd(configs, MODEL_TYPE)
    logging.info(f'Export with commands: {cmd}')
    run_in_subprocess(cmd)

    # Rebuild the config with the same user_config_params handed to the
    # export command, so the reference model matches what was exported.
    cfg = mmcv_config_fromfile(configs['config_file'])
    cfg = rebuild_config(cfg, configs['user_config_params'])

    # Disable the `pretrained` option for the reference model
    # (presumably to avoid loading/fetching pretrained weights -- confirm).
    if hasattr(cfg.model, 'pretrained'):
        cfg.model.pretrained = False

    torch_model = build_model(cfg.model).eval()
    if 'checkpoint' in override_configs:
        load_checkpoint(
            torch_model,
            override_configs['checkpoint'],
            strict=False,
            logger=logging.getLogger())

    session = onnxruntime.InferenceSession(configs['output_filename'] +
                                           '.onnx')
    input_tensor = torch.randn((1, 3, img_size, img_size))

    # Pure inference: no autograd graph is needed for the reference output.
    with torch.no_grad():
        torch_output = torch_model(input_tensor, mode='test')['prob']

    onnx_output = session.run(
        [session.get_outputs()[0].name],
        {session.get_inputs()[0].name: input_tensor.numpy()})
    # Only one output name was requested, so unwrap the single result.
    if isinstance(onnx_output, list):
        onnx_output = onnx_output[0]
    onnx_output = torch.tensor(onnx_output)

    is_same_shape = torch_output.shape == onnx_output.shape
    self.assertTrue(
        is_same_shape,
        f'The shapes of the two outputs are mismatch, got {torch_output.shape} and {onnx_output.shape}'
    )

    is_allclose = torch.allclose(
        torch_output, onnx_output, rtol=rtol, atol=atol)

    # Summary statistics included in the failure message to aid debugging.
    torch_out_minmax = f'{float(torch_output.min())}~{float(torch_output.max())}'
    onnx_out_minmax = f'{float(onnx_output.min())}~{float(onnx_output.max())}'
    info_msg = f'got avg: {float(torch_output.mean())} and {float(onnx_output.mean())},'
    info_msg += f' and range: {torch_out_minmax} and {onnx_out_minmax}'
    self.assertTrue(
        is_allclose,
        f'The values between the two outputs are mismatch, {info_msg}')
109

Callers 4

test_inceptionv3 (Method) · 0.95
test_inceptionv4 (Method) · 0.95
test_resnext50 (Method) · 0.95
test_mobilenetv2 (Method) · 0.95

Calls 9

run_in_subprocess (Function) · 0.90
mmcv_config_fromfile (Function) · 0.90
rebuild_config (Function) · 0.90
build_model (Function) · 0.90
load_checkpoint (Function) · 0.90
build_cmd (Function) · 0.85
copy (Method) · 0.45
update (Method) · 0.45
info (Method) · 0.45

Tested by

no test coverage detected