Repository: github.com/modelscope/FunASR

Function _bladedisc_opt_for_encdec

funasr/utils/export_utils.py:262–303

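Internal: optimizes an encoder-decoder model with BladeDISC (encoder and decoder separately), then traces the whole model and saves it as TorchScript. Args: model: model instance (not a name) providing export_dummy_inputs(), encoder, decoder and export_name. path: output directory for <export_name>_blade.torchscript. enable_fp16: if true, the encoder is rescaled beforehand to prevent FP16 overflow.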

(model, path, enable_fp16)
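The model argument is used duck-typed: the body calls model.export_dummy_inputs(), model.cuda() and model(*input_data), hooks and replaces model.encoder and model.decoder, and reads model.export_name. A torch.nn.Module subclass with those attributes fits. The sketch below writes that contract down as a typing.Protocol; EncDecExportable, missing_members and DummyModel are hypothetical names for illustration, not part of FunASR.

```python
from typing import Any, List, Protocol, runtime_checkable


@runtime_checkable
class EncDecExportable(Protocol):
    """Hypothetical protocol: what _bladedisc_opt_for_encdec touches on `model`."""

    encoder: Any      # replaced by its BladeDISC-optimized version
    decoder: Any      # hooked to capture its inputs, then replaced
    export_name: str  # used in the output file name

    def export_dummy_inputs(self) -> Any: ...
    def cuda(self) -> Any: ...
    def __call__(self, *args: Any) -> Any: ...


def missing_members(model: object) -> List[str]:
    """List the attributes the export helper would need but `model` lacks."""
    required = ("encoder", "decoder", "export_name", "export_dummy_inputs", "cuda")
    return [name for name in required if not hasattr(model, name)]


class DummyModel:
    """Minimal object satisfying the protocol, for illustration only."""

    export_name = "dummy"

    def __init__(self) -> None:
        self.encoder = object()
        self.decoder = object()

    def export_dummy_inputs(self) -> Any:
        return ()

    def cuda(self) -> "DummyModel":
        return self

    def __call__(self, *args: Any) -> Any:
        return args


print(isinstance(DummyModel(), EncDecExportable))  # True
print(missing_members("paraformer"))
# ['encoder', 'decoder', 'export_name', 'export_dummy_inputs', 'cuda']
```

A check like missing_members could run before any CUDA work, so that passing, say, a model name string fails with a clear list of missing attributes rather than an AttributeError partway through the export.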

Source

import os

import torch


def _bladedisc_opt_for_encdec(model, path, enable_fp16):
    """Internal: optimize an encoder-decoder model with BladeDISC and save it as TorchScript.

    Args:
        model: Model instance (not a name); must provide export_dummy_inputs(),
            encoder, decoder and export_name.
        path: Directory in which "<export_name>_blade.torchscript" is written.
        enable_fp16: If True, rescale the encoder before optimization to
            prevent FP16 overflow.
    """
    # Get input data
    # TODO: better to use real data
    input_data = model.export_dummy_inputs()
    if isinstance(input_data, torch.Tensor):
        # Wrap a single tensor so that model(*input_data) and input_data[:2]
        # operate on whole inputs rather than on the tensor's first dimension.
        input_data = (input_data.cuda(),)
    else:
        input_data = tuple(i.cuda() for i in input_data)

    # Get input data for decoder module: a forward pre-hook records the
    # positional arguments the decoder receives during one full forward pass.
    decoder_inputs = []

    def get_input_hook(m, x):
        """Forward pre-hook that records the decoder's inputs.

        Args:
            m: The module the hook is registered on (model.decoder).
            x: Tuple of positional arguments passed to m.forward.
        """
        decoder_inputs.extend(x)

    hook = model.decoder.register_forward_pre_hook(get_input_hook)
    model = model.cuda()
    model(*input_data)
    hook.remove()

    # Prevent FP16 overflow
    if enable_fp16:
        _rescale_encoder_model(model, input_data)

    # Optimize the encoder (fed the first two dummy inputs) and the decoder
    # (fed the captured inputs) separately with BladeDISC.
    model.encoder = _bladedisc_opt(model.encoder, input_data[:2])
    model.decoder = _bladedisc_opt(model.decoder, tuple(decoder_inputs))
    # Trace the whole model with the optimized submodules and save it.
    model_script = torch.jit.trace(model, input_data)
    model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript"))
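The decoder inputs are not known up front because the decoder is called from inside the model's forward pass, so the function captures them with a PyTorch forward pre-hook, which is called as hook(module, args) just before module.forward runs. A minimal CPU-only sketch of that pattern follows; ToyDecoder, ToyEncDec and capture_inputs are hypothetical names, not FunASR code.

```python
import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Hypothetical decoder that takes two positional inputs."""

    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 2)

    def forward(self, enc: torch.Tensor, enc_lens: torch.Tensor) -> torch.Tensor:
        # Zero out frames past each sequence's length, then project.
        mask = torch.arange(enc.size(1)).unsqueeze(0) < enc_lens.unsqueeze(1)
        return self.proj(enc * mask.unsqueeze(-1))


class ToyEncDec(nn.Module):
    """Hypothetical encoder-decoder; the decoder's inputs exist only inside forward()."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.decoder = ToyDecoder()

    def forward(self, speech: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        return self.decoder(self.encoder(speech), lengths)


def capture_inputs(model: nn.Module, submodule: nn.Module, inputs: tuple) -> tuple:
    """Run model once and return the positional args that submodule received."""
    captured = []

    def hook(module, args):
        # args is the tuple passed to submodule.forward; returning None
        # leaves the inputs unchanged.
        captured.extend(args)

    handle = submodule.register_forward_pre_hook(hook)
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        handle.remove()  # detach the hook even if forward() raises
    return tuple(captured)


torch.manual_seed(0)
model = ToyEncDec().eval()
speech = torch.randn(2, 5, 4)  # (batch, frames, features)
lengths = torch.tensor([5, 3])
dec_inputs = capture_inputs(model, model.decoder, (speech, lengths))
print([tuple(t.shape) for t in dec_inputs])  # [(2, 5, 8), (2,)]
```

Unlike the original, this sketch removes the hook in a finally block and runs the capture pass under torch.no_grad(); both are optional hardening choices rather than requirements of the technique.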

Callers 1

export (function) · 0.85

Calls 3

_rescale_encoder_model (function) · 0.85
_bladedisc_opt (function) · 0.85
export_dummy_inputs (method) · 0.45
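When enable_fp16 is set, _rescale_encoder_model runs first "to prevent FP16 overflow". The overflow is easy to demonstrate: the largest finite IEEE 754 half-precision value is 65504, and Python's struct module raises OverflowError when asked to pack a larger value with the "e" (binary16) format. The power-of-two scale rule and the headroom factor below are assumptions for illustration, not necessarily what _rescale_encoder_model computes.

```python
import math
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 binary16 value


def fits_fp16(x: float) -> bool:
    """True if x can be stored as a finite half-precision float."""
    try:
        struct.pack("<e", x)  # "e" is the binary16 format code
    except OverflowError:
        return False
    return True


def fp16_scale(absmax: float, headroom: float = 2.0) -> int:
    """Hypothetical rule: smallest power of two s with absmax * headroom / s <= FP16_MAX."""
    needed = absmax * headroom / FP16_MAX
    if needed <= 1.0:
        return 1
    return 2 ** math.ceil(math.log2(needed))


absmax = 100_000.0  # pretend peak |activation| observed on the dummy inputs
scale = fp16_scale(absmax)
print(fits_fp16(absmax), scale, fits_fp16(absmax / scale))  # False 4 True
```

Dividing activations by a power of two only shifts the exponent, so in this sketch the rescaling itself loses no mantissa precision; the headroom factor leaves room for intermediate values that exceed the peak seen on the dummy inputs.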

Tested by

no test coverage detected
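Exercising the BladeDISC step needs a GPU and torch_blade, but the function's last step (trace, save, reload) can be smoke-tested on CPU. The pytest-style sketch below uses a hypothetical ToyEncDec in place of a real FunASR model and skips BladeDISC entirely; trace_and_save and test_trace_round_trip are illustrative names.

```python
import os
import tempfile

import torch
from torch import nn


class ToyEncDec(nn.Module):
    """Hypothetical stand-in for a FunASR encoder-decoder model."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.decoder = nn.Linear(8, 2)

    def forward(self, speech: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        enc = torch.relu(self.encoder(speech))
        # Zero out frames past each sequence's length.
        mask = torch.arange(enc.size(1)).unsqueeze(0) < lengths.unsqueeze(1)
        return self.decoder(enc * mask.unsqueeze(-1))


def trace_and_save(model: nn.Module, inputs: tuple, out_dir: str, export_name: str) -> str:
    """Trace model on inputs and save it using the "<name>_blade.torchscript" pattern."""
    with torch.no_grad():
        traced = torch.jit.trace(model.eval(), inputs)
    out_path = os.path.join(out_dir, f"{export_name}_blade.torchscript")
    traced.save(out_path)
    return out_path


def test_trace_round_trip() -> None:
    torch.manual_seed(0)
    model = ToyEncDec()
    inputs = (torch.randn(2, 5, 4), torch.tensor([5, 3]))
    with tempfile.TemporaryDirectory() as tmp:
        out_path = trace_and_save(model, inputs, tmp, "toy")
        assert os.path.basename(out_path) == "toy_blade.torchscript"
        reloaded = torch.jit.load(out_path)
    with torch.no_grad():
        # The reloaded TorchScript module should match eager execution.
        assert torch.allclose(reloaded(*inputs), model(*inputs))


test_trace_round_trip()
```

Loading a file actually produced by BladeDISC may additionally require torch_blade to be importable at load time; that is an assumption this toy test does not cover.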
