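Internal: ONNX export/optimization helper for encoder-decoder models. Args: model: Model instance (must expose `decoder`, `export_dummy_inputs()`, `export_name` and the other `export_*` helpers; it is moved to CUDA). path: Directory in which the exported `.onnx` files are written. enable_fp16: If true, the encoder is rescaled before export to prevent FP16 overflow.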
(model, path, enable_fp16)
| 304 | |
| 305 | |
| 306 | def _onnx_opt_for_encdec(model, path, enable_fp16): |
| 307 | |
| 308 | # Get input data |
| 309 | # TODO: better to use real data |
| 310 | """Internal: onnx opt for encdec. |
| 311 | |
| 312 | Args: |
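| 313 | model: Model instance exposing ``decoder``, ``export_dummy_inputs()``, ``export_name`` and the other ``export_*`` helpers; it is moved to CUDA. |
| 314 | path: Directory in which the exported ``.onnx`` files are written. |
| 315 | enable_fp16: If true, rescale the encoder via ``_rescale_encoder_model`` before export to prevent FP16 overflow. |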
| 316 | """ |
| 317 | input_data = model.export_dummy_inputs() |
| 318 | |
| 319 | if isinstance(input_data, torch.Tensor): |
| 320 | input_data = input_data.cuda() |
| 321 | else: |
| 322 | input_data = tuple([i.cuda() for i in input_data]) |
| 323 | |
| 324 | # Get input data for decoder module |
| 325 | decoder_inputs = list() |
| 326 | |
| 327 | def get_input_hook(m, x): |
| 328 | """Forward pre-hook that records the positional inputs passed to the decoder. |
| 329 | |
| 330 | Args: |
| 331 | m: Module the hook is registered on (unused here). |
| 332 | x: Positional inputs of that module's forward call; appended to ``decoder_inputs``. |
| 333 | """ |
| 334 | decoder_inputs.extend(list(x)) |
| 335 | |
| 336 | hook = model.decoder.register_forward_pre_hook(get_input_hook) |
| 337 | model = model.cuda() |
| 338 | model(*input_data) |
| 339 | hook.remove() |
| 340 | |
| 341 | # Prevent FP16 overflow |
| 342 | if enable_fp16: |
| 343 | _rescale_encoder_model(model, input_data) |
| 344 | |
| 345 | fp32_model_path = f"{path}/{model.export_name}_hook.onnx" |
| 346 | print("*" * 50) |
| 347 | print(f"[_onnx_opt_for_encdec(fp32)]: {fp32_model_path}\n\n") |
| 348 | if not os.path.exists(fp32_model_path): |
| 349 | |
| 350 | torch.onnx.export( |
| 351 | model, |
| 352 | input_data, |
| 353 | fp32_model_path, |
| 354 | verbose=False, |
| 355 | do_constant_folding=True, |
| 356 | opset_version=13, |
| 357 | input_names=model.export_input_names(), |
| 358 | output_names=model.export_output_names(), |
| 359 | dynamic_axes=model.export_dynamic_axes(), |
| 360 | ) |
| 361 | |
| 362 | # fp32 to fp16 |
| 363 | fp16_model_path = f"{path}/{model.export_name}_hook_fp16.onnx" |
no test coverage detected
searching dependent graphs…