hub / github.com/modelscope/FunASR / _onnx_opt_for_encdec

Function _onnx_opt_for_encdec

funasr/utils/export_utils.py:306–375

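Internal helper that exports an encoder-decoder model to ONNX, capturing the decoder's inputs with a forward pre-hook during one dummy forward pass on CUDA. Args: model: Model instance that provides the export_* methods and a decoder submodule. path: Directory where the .onnx files are written. enable_fp16: If true, rescale the encoder before export to prevent FP16 overflow.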

(model, path, enable_fp16)
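The `enable_fp16` flag gates a rescaling step that the source annotates with "Prevent FP16 overflow". As a rough illustration of the limit involved, the sketch below uses only the standard library to show which values fit in half precision. `fits_in_fp16` and `rescale_to_fit` are hypothetical helpers written for this page. The body of `_rescale_encoder_model` is not shown here, so this uniform rescale is an assumption and should not be read as FunASR's method.

```python
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def fits_in_fp16(value):
    """Return True if value can be stored as a finite half-precision float."""
    try:
        struct.pack("<e", value)  # "e" is the IEEE 754 binary16 format code
        return True
    except OverflowError:
        return False


def rescale_to_fit(values, limit=FP16_MAX):
    """Hypothetical helper: divide by one common factor so the peak fits under limit."""
    peak = max(abs(v) for v in values)
    scale = 1.0 if peak <= limit else peak / limit
    return [v / scale for v in values], scale


activations = [1.5, -70000.0, 120000.0]
print([fits_in_fp16(v) for v in activations])

scaled, scale = rescale_to_fit(activations)
print(all(fits_in_fp16(v) for v in scaled), scale)
```

Any activation whose magnitude exceeds 65504 cannot be represented as a finite FP16 value, which is presumably why the encoder is rescaled before the FP16 conversion.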

Source from the content-addressed store, hash-verified

def _onnx_opt_for_encdec(model, path, enable_fp16):

    # Get input data
    # TODO: better to use real data
    """Internal: onnx opt for encdec.

    Args:
        model: Model instance or model name.
        path: TODO.
        enable_fp16: TODO.
    """
    input_data = model.export_dummy_inputs()

    if isinstance(input_data, torch.Tensor):
        input_data = input_data.cuda()
    else:
        input_data = tuple([i.cuda() for i in input_data])

    # Get input data for decoder module
    decoder_inputs = list()

    def get_input_hook(m, x):
        """Get input hook.

        Args:
            m: TODO.
            x: TODO.
        """
        decoder_inputs.extend(list(x))

    hook = model.decoder.register_forward_pre_hook(get_input_hook)
    model = model.cuda()
    model(*input_data)
    hook.remove()

    # Prevent FP16 overflow
    if enable_fp16:
        _rescale_encoder_model(model, input_data)

    fp32_model_path = f"{path}/{model.export_name}_hook.onnx"
    print("*" * 50)
    print(f"[_onnx_opt_for_encdec(fp32)]: {fp32_model_path}\n\n")
    if not os.path.exists(fp32_model_path):

        torch.onnx.export(
            model,
            input_data,
            fp32_model_path,
            verbose=False,
            do_constant_folding=True,
            opset_version=13,
            input_names=model.export_input_names(),
            output_names=model.export_output_names(),
            dynamic_axes=model.export_dynamic_axes(),
        )

    # fp32 to fp16
    fp16_model_path = f"{path}/{model.export_name}_hook_fp16.onnx"
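
The decoder-input capture in the listing relies on PyTorch's forward pre-hook mechanism. A minimal CPU-only sketch of that pattern follows. `TinyEncDec` and `capture_decoder_inputs` are hypothetical names invented for illustration, and the `try`/`finally` and `torch.no_grad()` wrappers are additions that the original function does not use.

```python
import torch
from torch import nn


class TinyEncDec(nn.Module):
    """Hypothetical stand-in for an encoder-decoder model (not a FunASR class)."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.decoder = nn.Linear(8, 2)

    def forward(self, feats):
        return self.decoder(self.encoder(feats))


def capture_decoder_inputs(model, *inputs):
    """Run one forward pass and return the positional inputs the decoder received."""
    captured = []

    def pre_hook(module, args):
        # args is the tuple of positional arguments passed to module.forward
        captured.extend(args)

    handle = model.decoder.register_forward_pre_hook(pre_hook)
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        # Always detach the hook so later forward passes are not recorded
        handle.remove()
    return captured


model = TinyEncDec().eval()
decoder_inputs = capture_decoder_inputs(model, torch.randn(3, 4))
print([tuple(t.shape) for t in decoder_inputs])
```

Because the hook fires before `decoder.forward`, the captured tensors are the encoder outputs exactly as the decoder sees them, which makes them usable as example inputs for a decoder exported on its own.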

Callers 1

export · Function · 0.85

Calls 6

_rescale_encoder_model · Function · 0.85
export_input_names · Method · 0.80
export_output_names · Method · 0.80
export_dynamic_axes · Method · 0.80
export_dummy_inputs · Method · 0.45
export · Method · 0.45
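
Taken together, the calls above imply that `model` must provide a small export interface before `_onnx_opt_for_encdec` can run. The sketch below is one way to check for it up front. `REQUIRED_MEMBERS`, `missing_export_members`, and `StubModel` are hypothetical names, and the input and output names inside the stub are invented placeholders and not values from FunASR.

```python
# Members that the listing reads from `model`; the tuple name is hypothetical.
REQUIRED_MEMBERS = (
    "decoder",
    "export_name",
    "export_dummy_inputs",
    "export_input_names",
    "export_output_names",
    "export_dynamic_axes",
)


def missing_export_members(model):
    """Return the names from REQUIRED_MEMBERS that `model` does not provide."""
    return [name for name in REQUIRED_MEMBERS if not hasattr(model, name)]


class StubModel:
    """Illustrative stand-in; the names and shapes here are invented."""

    export_name = "model"
    decoder = object()

    def export_dummy_inputs(self):
        return ([[0.0] * 4],)  # placeholder for a tuple of tensors

    def export_input_names(self):
        return ["feats"]

    def export_output_names(self):
        return ["logits"]

    def export_dynamic_axes(self):
        # Same mapping shape that torch.onnx.export accepts for dynamic_axes
        return {"feats": {0: "batch"}, "logits": {0: "batch"}}


print(missing_export_members(StubModel()))
print(missing_export_members(object()))
```

A check like this only confirms that the attributes exist. The listing additionally calls the model itself and moves it to CUDA, so a real model must also be a callable PyTorch module.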

Tested by

no test coverage detected
