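Internal: optimize an encoder-decoder model with BladeDISC and save it as TorchScript. Args: model: model instance (must provide `export_dummy_inputs()`, `encoder`, `decoder`, and `export_name`). path: directory the `<export_name>_blade.torchscript` file is written to. enable_fp16: if true, rescale the encoder before optimization to prevent FP16 overflow.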
(model, path, enable_fp16)
| 260 | |
| 261 | |
def _bladedisc_opt_for_encdec(model, path, enable_fp16):
    """Optimize an encoder-decoder model with BladeDISC and save it as TorchScript.

    The encoder and decoder submodules are optimized separately via
    ``_bladedisc_opt``. The whole model is then traced with
    ``torch.jit.trace`` and saved to
    ``<path>/<model.export_name>_blade.torchscript``.

    Note: this mutates ``model`` in place. It is moved to CUDA, and its
    ``encoder`` / ``decoder`` attributes are replaced by the optimized modules.

    Args:
        model: Model instance exposing ``export_dummy_inputs()``, ``encoder``,
            ``decoder`` and ``export_name``.
        path: Directory the TorchScript file is written to.
        enable_fp16: If true, ``_rescale_encoder_model`` is applied before
            optimization to prevent FP16 overflow.

    Raises:
        RuntimeError: If running the model never invokes ``model.decoder``,
            so no example inputs for the decoder could be captured.
    """
    # Get example input data.
    # TODO: better to use real data
    input_data = model.export_dummy_inputs()
    if isinstance(input_data, torch.Tensor):
        # Wrap a lone tensor so that ``*input_data`` and ``input_data[:2]``
        # act on the argument list, not on the tensor's first dimension.
        input_data = (input_data.cuda(),)
    else:
        input_data = tuple(i.cuda() for i in input_data)

    # Capture the positional arguments of the FIRST decoder call; they serve
    # as example inputs for tracing the decoder on its own. A forward
    # pre-hook only sees positional args.
    # NOTE(review): assumes the model calls its decoder positionally.
    captured = []

    def get_input_hook(m, x):
        """Record the positional inputs of the first decoder invocation.

        Args:
            m: The hooked module (unused).
            x: Tuple of positional arguments passed to ``m.forward``.
        """
        if not captured:
            captured.append(tuple(x))

    hook = model.decoder.register_forward_pre_hook(get_input_hook)
    try:
        model = model.cuda()
        model(*input_data)
    finally:
        # Always detach the hook so a failure here does not leave it
        # registered on the caller's model.
        hook.remove()

    if not captured:
        raise RuntimeError(
            "model.decoder was not called during the forward pass; "
            "cannot capture example inputs for the decoder."
        )
    decoder_inputs = captured[0]

    # Prevent FP16 overflow
    if enable_fp16:
        _rescale_encoder_model(model, input_data)

    # Export and optimize encoder/decoder modules.
    # NOTE(review): the encoder is traced with the first two model inputs;
    # presumably it takes exactly those -- confirm against model definitions.
    model.encoder = _bladedisc_opt(model.encoder, input_data[:2])
    model.decoder = _bladedisc_opt(model.decoder, decoder_inputs)
    model_script = torch.jit.trace(model, input_data)
    model_script.save(os.path.join(path, f"{model.export_name}_blade.torchscript"))
| 305 | |
| 306 | def _onnx_opt_for_encdec(model, path, enable_fp16): |
no test coverage detected
searching dependent graphs…