(self,
CONFIG_FILE,
MODEL_TYPE,
img_size: int = 224,
**override_configs)
| 50 | super().tearDown() |
| 51 | |
def run_test(self,
             CONFIG_FILE,
             MODEL_TYPE,
             img_size: int = 224,
             **override_configs):
    """Export a model to ONNX and compare it with the PyTorch model.

    The export command is built from ``BASIC_EXPORT_CONFIGS``, with
    ``config_file`` set to ``CONFIG_FILE`` and then ``override_configs``
    applied on top. The command runs in a subprocess. The produced
    ``<output_filename>.onnx`` file is then run with onnxruntime on a random
    input. Its first output is compared against the ``'prob'`` entry returned
    by the PyTorch model built from the same config.

    Args:
        CONFIG_FILE: Path of the model config file to export.
        MODEL_TYPE: Model type forwarded to ``build_cmd``.
        img_size: Height and width of the random ``(1, 3, H, W)`` input.
        **override_configs: Entries that override the export configs
            (applied after ``config_file`` is set, so they win). If
            ``checkpoint`` is present, it is also loaded non-strictly into
            the PyTorch reference model.

    Raises:
        AssertionError: If the output shapes differ, or if the values are not
            ``torch.allclose`` under its default tolerances.
    """
    configs = BASIC_EXPORT_CONFIGS.copy()
    configs['config_file'] = CONFIG_FILE
    configs.update(override_configs)

    cmd = build_cmd(configs, MODEL_TYPE)
    logging.info('Export with commands: %s', cmd)
    run_in_subprocess(cmd)

    # Build the reference model from the same config file and the same
    # user_config_params that were handed to the export command.
    cfg = mmcv_config_fromfile(configs['config_file'])
    cfg = rebuild_config(cfg, configs['user_config_params'])
    if hasattr(cfg.model, 'pretrained'):
        # Disable `pretrained` on the reference model; presumably this avoids
        # fetching pretrained weights during build_model — confirm.
        cfg.model.pretrained = False

    torch_model = build_model(cfg.model).eval()
    if 'checkpoint' in override_configs:
        # NOTE(review): only an explicitly overridden checkpoint is loaded
        # here. If BASIC_EXPORT_CONFIGS supplies a default checkpoint, the
        # exported model may hold weights this reference lacks — confirm.
        load_checkpoint(
            torch_model,
            override_configs['checkpoint'],
            strict=False,
            logger=logging.getLogger())

    # Pin the CPU provider explicitly. onnxruntime >= 1.9 GPU builds refuse
    # to create a session without a providers list, and CPU matches the
    # PyTorch reference, which also runs on CPU here.
    session = onnxruntime.InferenceSession(
        configs['output_filename'] + '.onnx',
        providers=['CPUExecutionProvider'])

    input_tensor = torch.randn((1, 3, img_size, img_size))

    # Inference only: no autograd graph is needed for the comparison.
    with torch.no_grad():
        torch_output = torch_model(input_tensor, mode='test')['prob']

    onnx_output = session.run(
        [session.get_outputs()[0].name],
        {session.get_inputs()[0].name: input_tensor.numpy()})
    if isinstance(onnx_output, list):
        onnx_output = onnx_output[0]
    onnx_output = torch.tensor(onnx_output)

    is_same_shape = torch_output.shape == onnx_output.shape
    self.assertTrue(
        is_same_shape,
        f'The shapes of the two outputs are mismatch, got {torch_output.shape} and {onnx_output.shape}'
    )

    is_allclose = torch.allclose(torch_output, onnx_output)

    # Summary statistics make a value mismatch easier to diagnose.
    torch_out_minmax = f'{float(torch_output.min())}~{float(torch_output.max())}'
    onnx_out_minmax = f'{float(onnx_output.min())}~{float(onnx_output.max())}'
    info_msg = f'got avg: {float(torch_output.mean())} and {float(onnx_output.mean())},'
    info_msg += f' and range: {torch_out_minmax} and {onnx_out_minmax}'
    self.assertTrue(
        is_allclose,
        f'The values between the two outputs are mismatch, {info_msg}')
no test coverage detected