github.com/modelscope/FunASR

Function _onnx

funasr/utils/export_utils.py:66–136

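Internal helper that exports a model to ONNX with torch.onnx.export and can optionally write a quantized copy.

Args:
  model: Model instance providing the export hooks export_dummy_inputs, export_name, export_input_names, export_output_names and export_dynamic_axes. The body calls these hooks directly, so the lines shown expect a model instance rather than a model name.
  data_in: Input data (audio samples, file paths, or text). Not used in the lines shown; tracing uses model.export_dummy_inputs() instead.
  quantize: If True, also derive a <name>_quant.onnx path and run the quantization step, which imports quantize_dynamic from onnxruntime.quantization; requires onnx and onnxruntime.
  opset_version: ONNX opset version passed to torch.onnx.export (default 14).
  export_dir: Directory the exported .onnx file is written to.
  **kwargs: Additional keyword arguments; the lines shown read device (default "cpu") and verbose (default False).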
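The function body reads only two keys from **kwargs, device (default "cpu") and verbose (default False), and moves whatever model.export_dummy_inputs() returns onto that device: a single torch.Tensor is moved directly, while anything else is treated as a sequence and rebuilt as a tuple, the form torch.onnx.export accepts for several positional inputs. The sketch below mirrors that logic under stated assumptions: FakeTensor is a hypothetical stand-in for torch.Tensor so the example runs without PyTorch, and the input names "x" and "x_len" are illustrative only.

```python
from dataclasses import dataclass
from typing import Iterable, Tuple, Union


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only tracks its device."""

    name: str
    device: str = "cpu"

    def to(self, device: str) -> "FakeTensor":
        # Like Tensor.to(device): return a tensor placed on the requested device.
        return FakeTensor(self.name, device)


def read_export_kwargs(**kwargs) -> Tuple[str, bool]:
    """Return (device, verbose) using the same defaults as _onnx."""
    return kwargs.get("device", "cpu"), kwargs.get("verbose", False)


def prepare_dummy_input(
    dummy: Union[FakeTensor, Iterable[FakeTensor]], device: str
) -> Union[FakeTensor, Tuple[FakeTensor, ...]]:
    """Mirror _onnx: move a single tensor, or move each element and return a tuple."""
    if isinstance(dummy, FakeTensor):
        return dummy.to(device)
    return tuple(t.to(device) for t in dummy)


# Usage: a two-input model exported on a hypothetical "cuda" device.
device, verbose = read_export_kwargs(device="cuda")
inputs = prepare_dummy_input([FakeTensor("x"), FakeTensor("x_len")], device)
print(device, verbose, inputs)
```

Returning a tuple rather than a list matches how torch.onnx.export interprets its args parameter: a tuple is unpacked into the model's positional arguments.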

(
    model,
    data_in=None,
    quantize: bool = False,
    opset_version: int = 14,
    export_dir: str = None,
    **kwargs,
)

def _onnx(
    model,
    data_in=None,
    quantize: bool = False,
    opset_version: int = 14,
    export_dir: str = None,
    **kwargs,
):

    """Internal: onnx.

    Args:
        model: Model instance or model name.
        data_in: Input data (audio samples, file paths, or text).
        quantize: TODO.
        opset_version: TODO.
        export_dir: TODO.
        **kwargs: Additional keyword arguments.
    """
    device = kwargs.get("device", "cpu")
    dummy_input = model.export_dummy_inputs()

    if isinstance(dummy_input, torch.Tensor):
        dummy_input = dummy_input.to(device)
    else:
        dummy_input = tuple([input.to(device) for input in dummy_input])

    verbose = kwargs.get("verbose", False)

    if isinstance(model.export_name, str):
        export_name = model.export_name + ".onnx"
    else:
        export_name = model.export_name()
    model_path = os.path.join(export_dir, export_name)
    torch.onnx.export(
        model,
        dummy_input,
        model_path,
        verbose=verbose,
        do_constant_folding=True,
        opset_version=opset_version,
        input_names=model.export_input_names(),
        output_names=model.export_output_names(),
        dynamic_axes=model.export_dynamic_axes(),
    )

    if quantize:
        try:
            from onnxruntime.quantization import QuantType, quantize_dynamic
            import onnx
        except:
            raise RuntimeError(
                "You are quantizing the onnx model, please install onnxruntime first. via \n`pip install onnx`\n`pip install onnxruntime`."
            )

        quant_model_path = model_path.replace(".onnx", "_quant.onnx")
        onnx_model = onnx.load(model_path)
        nodes = [n.name for n in onnx_model.graph.node]
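Two details of the quantization preamble above deserve a closer look. The bare except: catches every exception, not only a missing package, and model_path.replace(".onnx", "_quant.onnx") rewrites every occurrence of ".onnx" in the path, so an export_dir whose name contains ".onnx" is rewritten as well. The sketch below is a hedged illustration, not FunASR code: require_modules is a hypothetical helper that catches only ImportError, and quant_path_splitext is a hypothetical alternative that changes only the final extension. It assumes each pip package name equals its import name, which holds for onnx and onnxruntime.

```python
import importlib
import os


def require_modules(*names: str) -> None:
    """Hypothetical helper: import each module, raising RuntimeError if any is missing.

    Only ImportError (which includes ModuleNotFoundError) is caught, so unrelated
    failures still surface with their own traceback.
    Assumes each pip package name equals its import name.
    """
    missing = []
    for name in names:
        try:
            importlib.import_module(name)
        except ImportError:
            missing.append(name)
    if missing:
        hints = "\n".join(f"pip install {name}" for name in missing)
        raise RuntimeError(f"Missing module(s): {', '.join(missing)}. Install via:\n{hints}")


def quant_path_replace(model_path: str) -> str:
    # What _onnx does: str.replace rewrites EVERY ".onnx" in the path, directories included.
    return model_path.replace(".onnx", "_quant.onnx")


def quant_path_splitext(model_path: str) -> str:
    # Hypothetical alternative: insert "_quant" before the final extension only.
    root, ext = os.path.splitext(model_path)
    return root + "_quant" + ext


if __name__ == "__main__":
    require_modules("json", "os")  # stdlib modules, so this passes silently
    print(quant_path_replace("out.onnx/model.onnx"))   # out_quant.onnx/model_quant.onnx
    print(quant_path_splitext("out.onnx/model.onnx"))  # out.onnx/model_quant.onnx
```

For an ordinary path such as "exports/model.onnx" both functions return "exports/model_quant.onnx"; they differ only when ".onnx" also appears earlier in the path.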

Callers 1

export (function) · 0.85

Calls 6

export_name (method) · 0.80
export_input_names (method) · 0.80
export_output_names (method) · 0.80
export_dynamic_axes (method) · 0.80
export_dummy_inputs (method) · 0.45
export (method) · 0.45
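The calls listed above show that _onnx relies on hooks on the model object rather than on a fixed base class. The sketch below writes those hooks down as a typing.Protocol. OnnxExportable, resolve_export_name, unknown_dynamic_axes and TinyModel are hypothetical names, and the return types are assumptions inferred from how torch.onnx.export consumes input_names, output_names and dynamic_axes. resolve_export_name reproduces the str-or-callable branch of _onnx; unknown_dynamic_axes is an extra consistency check that FunASR does not perform.

```python
from typing import Any, Dict, List, Protocol, runtime_checkable


@runtime_checkable
class OnnxExportable(Protocol):
    """Hypothetical protocol listing the hooks _onnx uses; return types are assumptions."""

    export_name: Any  # a str attribute, or a zero-argument method returning a file name

    def export_dummy_inputs(self) -> Any: ...
    def export_input_names(self) -> List[str]: ...
    def export_output_names(self) -> List[str]: ...
    def export_dynamic_axes(self) -> Dict[str, Dict[int, str]]: ...


def resolve_export_name(model: OnnxExportable) -> str:
    # Mirrors _onnx: a str gets ".onnx" appended; otherwise the call result is used as-is.
    if isinstance(model.export_name, str):
        return model.export_name + ".onnx"
    return model.export_name()


def unknown_dynamic_axes(model: OnnxExportable) -> List[str]:
    # Dynamic-axis keys that match no declared input or output name (likely typos).
    declared = set(model.export_input_names()) | set(model.export_output_names())
    return sorted(k for k in model.export_dynamic_axes() if k not in declared)


class TinyModel:
    """Illustrative stand-in; the names and axes are made up."""

    export_name = "tiny"

    def export_dummy_inputs(self):
        return None  # a real model would return a tensor or a tuple of tensors

    def export_input_names(self):
        return ["x", "x_len"]

    def export_output_names(self):
        return ["logits"]

    def export_dynamic_axes(self):
        return {"x": {0: "batch", 1: "time"}, "x_len": {0: "batch"}, "logits": {0: "batch", 1: "time"}}


model = TinyModel()
print(isinstance(model, OnnxExportable), resolve_export_name(model), unknown_dynamic_axes(model))
# True tiny.onnx []
```

Because _onnx appends ".onnx" only in the str case, a callable export_name has to return the complete file name, extension included.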

Tested by

no test coverage detected
