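Internal: export a model to ONNX via torch.onnx.export. Args: model: Model instance or model name. data_in: Input data (audio samples, file paths, or text). quantize: If True, the exported model is additionally processed for quantization (requires the onnx and onnxruntime packages; the quantized model path uses the "_quant.onnx" suffix). opset_version: ONNX opset version passed to torch.onnx.export (default 14). export_dir: Directory in which the exported .onnx file is written. **kwargs: Additional keyword arguments; "device" (default "cpu") and "verbose" (default False) are read.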
(
model,
data_in=None,
quantize: bool = False,
opset_version: int = 14,
export_dir: str = None,
**kwargs,
)
| 64 | |
| 65 | |
| 66 | def _onnx( |
| 67 | model, |
| 68 | data_in=None, |
| 69 | quantize: bool = False, |
| 70 | opset_version: int = 14, |
| 71 | export_dir: str = None, |
| 72 | **kwargs, |
| 73 | ): |
| 74 | |
| 75 | """Internal: onnx. |
| 76 | |
| 77 | Args: |
| 78 | model: Model instance or model name. |
| 79 | data_in: Input data (audio samples, file paths, or text). |
| 80 | quantize: TODO. |
| 81 | opset_version: TODO. |
| 82 | export_dir: TODO. |
| 83 | **kwargs: Additional keyword arguments. |
| 84 | """ |
| 85 | device = kwargs.get("device", "cpu") |
| 86 | dummy_input = model.export_dummy_inputs() |
| 87 | |
| 88 | if isinstance(dummy_input, torch.Tensor): |
| 89 | dummy_input = dummy_input.to(device) |
| 90 | else: |
| 91 | dummy_input = tuple([input.to(device) for input in dummy_input]) |
| 92 | |
| 93 | verbose = kwargs.get("verbose", False) |
| 94 | |
| 95 | if isinstance(model.export_name, str): |
| 96 | export_name = model.export_name + ".onnx" |
| 97 | else: |
| 98 | export_name = model.export_name() |
| 99 | model_path = os.path.join(export_dir, export_name) |
| 100 | torch.onnx.export( |
| 101 | model, |
| 102 | dummy_input, |
| 103 | model_path, |
| 104 | verbose=verbose, |
| 105 | do_constant_folding=True, |
| 106 | opset_version=opset_version, |
| 107 | input_names=model.export_input_names(), |
| 108 | output_names=model.export_output_names(), |
| 109 | dynamic_axes=model.export_dynamic_axes(), |
| 110 | ) |
| 111 | |
| 112 | if quantize: |
| 113 | try: |
| 114 | from onnxruntime.quantization import QuantType, quantize_dynamic |
| 115 | import onnx |
| 116 | except: |
| 117 | raise RuntimeError( |
| 118 | "You are quantizing the onnx model, please install onnxruntime first. via \n`pip install onnx`\n`pip install onnxruntime`." |
| 119 | ) |
| 120 | |
| 121 | quant_model_path = model_path.replace(".onnx", "_quant.onnx") |
| 122 | onnx_model = onnx.load(model_path) |
| 123 | nodes = [n.name for n in onnx_model.graph.node] |
no test coverage detected
searching dependent graphs…