| 362 | |
| 363 | |
def export_single_model(
    model,
    arch_config,
    save_path,
    logger,
    yaml_path,
    config,
    input_shape=None,
    quanter=None,
):
    """Convert ``model`` to a static graph and save it as an inference model.

    Args:
        model: Dynamic-graph model to export.
        arch_config: Architecture config forwarded to ``dynamic_to_static``.
        save_path: Save path/prefix passed to ``paddle.jit.save`` (or to
            ``quanter.save_quantized_model`` when a quantizer is given).
        logger: Logger used for status messages.
        yaml_path: Not used by this function; kept for interface
            compatibility with existing callers.
        config: Full config dict. ``config["Global"]["export_with_pir"]``
            (default ``True``) selects PIR export vs. legacy-IR export.
        input_shape: Optional input shape forwarded to ``dynamic_to_static``.
        quanter: Optional quantizer. When given, its
            ``save_quantized_model`` is used instead of ``paddle.jit.save``.

    Raises:
        RuntimeError: If PIR export is requested but the installed Paddle is
            older than 3.0.0b2 (and is not a ``0.0.0`` build), or the
            ``FLAGS_enable_pir_api`` environment variable is ``"0"`` or
            ``"False"``.
    """
    model = dynamic_to_static(model, arch_config, logger, input_shape)

    if quanter is not None:
        quanter.save_quantized_model(model, save_path)
    else:
        try:
            # Imported only for its side effects: the module enables
            # AIStudio's model encryption when it is installed.
            import encryption  # noqa: F401
        except ModuleNotFoundError:
            # Encryption is optional; export proceeds without it.
            logger.info("Skipping import of the encryption module")

        paddle_version = version.parse(paddle.__version__)
        # "0.0.0" is presumably a develop/source build; it is treated as new
        # enough, matching the original checks.
        pir_capable_paddle = paddle_version >= version.parse(
            "3.0.0b2"
        ) or paddle_version == version.parse("0.0.0")

        if config["Global"].get("export_with_pir", True):
            pir_api_enabled = os.environ.get("FLAGS_enable_pir_api", None) not in [
                "0",
                "False",
            ]
            # Explicit check instead of `assert` so the guard survives
            # `python -O` and explains how to recover.
            if not (pir_capable_paddle and pir_api_enabled):
                raise RuntimeError(
                    "Exporting with PIR requires paddle >= 3.0.0b2 and "
                    "FLAGS_enable_pir_api not set to '0'/'False' "
                    "(got paddle {}, FLAGS_enable_pir_api={!r}). "
                    "Set Global.export_with_pir to False to export with the "
                    "legacy IR.".format(
                        paddle.__version__,
                        os.environ.get("FLAGS_enable_pir_api", None),
                    )
                )
            paddle.jit.save(model, save_path)
        elif pir_capable_paddle:
            # On newer Paddle the first conversion presumably produced a PIR
            # program. Roll it back and re-convert under the legacy IR guard
            # so the saved model uses the old IR.
            model.forward.rollback()
            with paddle.pir_utils.OldIrGuard():
                model = dynamic_to_static(model, arch_config, logger, input_shape)
                paddle.jit.save(model, save_path)
        else:
            paddle.jit.save(model, save_path)

    logger.info("inference model is saved to {}".format(save_path))
| 405 | |
| 406 | |
| 407 | def convert_bn(model): |