| 7 | |
| 8 | |
| 9 | def export_detector(detector_onnx_save_path, |
| 10 | in_shape=[1, 3, 608, 800], |
| 11 | lang_list=["en"], |
| 12 | model_storage_directory=None, |
| 13 | user_network_directory=None, |
| 14 | download_enabled=True, |
| 15 | dynamic=True, |
| 16 | device="cpu", |
| 17 | quantize=True, |
| 18 | detector=True, |
| 19 | recognizer=True): |
| 20 | if dynamic is False: |
| 21 | print('WARNING: it is recommended to use -d dynamic flag when exporting onnx') |
| 22 | ocr_reader = easyocr.Reader(lang_list, |
| 23 | gpu=False if device == "cpu" else True, |
| 24 | detector=detector, |
| 25 | recognizer=detector, |
| 26 | quantize=quantize, |
| 27 | model_storage_directory=model_storage_directory, |
| 28 | user_network_directory=user_network_directory, |
| 29 | download_enabled=download_enabled) |
| 30 | |
| 31 | # exporting detector if selected |
| 32 | if detector: |
| 33 | dummy_input = torch.rand(in_shape) |
| 34 | dummy_input = dummy_input.to(device) |
| 35 | |
| 36 | # forward pass |
| 37 | with torch.no_grad(): |
| 38 | y_torch_out, feature_torch_out = ocr_reader.detector(dummy_input) |
| 39 | torch.onnx.export(ocr_reader.detector, |
| 40 | dummy_input, |
| 41 | detector_onnx_save_path, |
| 42 | export_params=True, |
| 43 | do_constant_folding=True, |
| 44 | opset_version=12, |
| 45 | # model's input names |
| 46 | input_names=['input'], |
| 47 | # model's output names, ignore the 2nd output |
| 48 | output_names=['output'], |
| 49 | # variable length axes |
| 50 | dynamic_axes={'input': {0: 'batch_size', 2: "height", 3: "width"}, |
| 51 | 'output': {0: 'batch_size', 1: "dim1", 2: "dim2"} |
| 52 | } if dynamic else None, |
| 53 | verbose=False) |
| 54 | |
| 55 | # verify exported onnx model |
| 56 | detector_onnx = onnx.load(detector_onnx_save_path) |
| 57 | onnx.checker.check_model(detector_onnx) |
| 58 | print(f"Model Inputs:\n {detector_onnx.graph.input}\n{'*'*80}") |
| 59 | print(f"Model Outputs:\n {detector_onnx.graph.output}\n{'*'*80}") |
| 60 | |
| 61 | # onnx inference validation |
| 62 | import onnxruntime |
| 63 | |
| 64 | ort_session = onnxruntime.InferenceSession(detector_onnx_save_path) |
| 65 | |
| 66 | def to_numpy(tensor): |