(config, base_model=None, save_path=None)
| 417 | |
| 418 | |
| 419 | def export(config, base_model=None, save_path=None): |
| 420 | if paddle.distributed.get_rank() != 0: |
| 421 | return |
| 422 | logger = get_logger() |
| 423 | # build post process |
| 424 | post_process_class = build_post_process(config["PostProcess"], config["Global"]) |
| 425 | |
| 426 | # build model |
| 427 | # for rec algorithm |
| 428 | if hasattr(post_process_class, "character"): |
| 429 | char_num = len(getattr(post_process_class, "character")) |
| 430 | if config["Architecture"]["algorithm"] in [ |
| 431 | "Distillation", |
| 432 | ]: # distillation model |
| 433 | for key in config["Architecture"]["Models"]: |
| 434 | if ( |
| 435 | config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead" |
| 436 | ): # multi head |
| 437 | out_channels_list = {} |
| 438 | if config["PostProcess"]["name"] == "DistillationSARLabelDecode": |
| 439 | char_num = char_num - 2 |
| 440 | if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode": |
| 441 | char_num = char_num - 3 |
| 442 | out_channels_list["CTCLabelDecode"] = char_num |
| 443 | out_channels_list["SARLabelDecode"] = char_num + 2 |
| 444 | out_channels_list["NRTRLabelDecode"] = char_num + 3 |
| 445 | config["Architecture"]["Models"][key]["Head"][ |
| 446 | "out_channels_list" |
| 447 | ] = out_channels_list |
| 448 | else: |
| 449 | config["Architecture"]["Models"][key]["Head"][ |
| 450 | "out_channels" |
| 451 | ] = char_num |
| 452 | # just one final tensor needs to be exported for inference |
| 453 | config["Architecture"]["Models"][key]["return_all_feats"] = False |
| 454 | elif config["Architecture"]["Head"]["name"] == "MultiHead": # multi head |
| 455 | out_channels_list = {} |
| 456 | char_num = len(getattr(post_process_class, "character")) |
| 457 | if config["PostProcess"]["name"] == "SARLabelDecode": |
| 458 | char_num = char_num - 2 |
| 459 | if config["PostProcess"]["name"] == "NRTRLabelDecode": |
| 460 | char_num = char_num - 3 |
| 461 | out_channels_list["CTCLabelDecode"] = char_num |
| 462 | out_channels_list["SARLabelDecode"] = char_num + 2 |
| 463 | out_channels_list["NRTRLabelDecode"] = char_num + 3 |
| 464 | config["Architecture"]["Head"]["out_channels_list"] = out_channels_list |
| 465 | else: # base rec model |
| 466 | config["Architecture"]["Head"]["out_channels"] = char_num |
| 467 | |
| 468 | # for sr algorithm |
| 469 | if config["Architecture"]["model_type"] == "sr": |
| 470 | config["Architecture"]["Transform"]["infer_mode"] = True |
| 471 | |
| 472 | # for latexocr algorithm |
| 473 | if config["Architecture"].get("algorithm") in ["LaTeXOCR"]: |
| 474 | config["Architecture"]["Backbone"]["is_predict"] = True |
| 475 | config["Architecture"]["Backbone"]["is_export"] = True |
| 476 | config["Architecture"]["Head"]["is_export"] = True |
no test coverage detected
searching dependent graphs…