Repository: github.com/PaddlePaddle/PaddleOCR

Function export(config, base_model=None, save_path=None)

ppocr/utils/export_model.py, lines 419–547

def export(config, base_model=None, save_path=None):
    if paddle.distributed.get_rank() != 0:
        return
    logger = get_logger()
    # build post process
    post_process_class = build_post_process(config["PostProcess"], config["Global"])

    # build model
    # for rec algorithm
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in [
            "Distillation",
        ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if (
                    config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
                ):  # multi head
                    out_channels_list = {}
                    if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
                        char_num = char_num - 2
                    if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
                        char_num = char_num - 3
                    out_channels_list["CTCLabelDecode"] = char_num
                    out_channels_list["SARLabelDecode"] = char_num + 2
                    out_channels_list["NRTRLabelDecode"] = char_num + 3
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels_list"
                    ] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"
                    ] = char_num
                # just one final tensor needs to be exported for inference
                config["Architecture"]["Models"][key]["return_all_feats"] = False
        elif config["Architecture"]["Head"]["name"] == "MultiHead":  # multi head
            out_channels_list = {}
            char_num = len(getattr(post_process_class, "character"))
            if config["PostProcess"]["name"] == "SARLabelDecode":
                char_num = char_num - 2
            if config["PostProcess"]["name"] == "NRTRLabelDecode":
                char_num = char_num - 3
            out_channels_list["CTCLabelDecode"] = char_num
            out_channels_list["SARLabelDecode"] = char_num + 2
            out_channels_list["NRTRLabelDecode"] = char_num + 3
            config["Architecture"]["Head"]["out_channels_list"] = out_channels_list
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num

    # for sr algorithm
    if config["Architecture"]["model_type"] == "sr":
        config["Architecture"]["Transform"]["infer_mode"] = True

    # for latexocr algorithm
    if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
        config["Architecture"]["Backbone"]["is_predict"] = True
        config["Architecture"]["Backbone"]["is_export"] = True
        config["Architecture"]["Head"]["is_export"] = True
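
The rec branch of the listing sizes the recognition head from the length of the post-processor's character list. For a MultiHead head, char_num is first reduced by 2 (SAR label decoding) or 3 (NRTR label decoding) so that it equals the CTC head size, and the SAR and NRTR heads are then sized at char_num + 2 and char_num + 3; a Distillation config repeats this per sub-model under Architecture.Models and also sets return_all_feats to False. The sketch below mirrors that logic on a plain nested dict. It is an illustration under stated assumptions, not PaddleOCR code: size_rec_heads and multihead_channels are hypothetical names, and a Python list stands in for post_process_class.character.

```python
from typing import Any, Dict, List, Optional


def multihead_channels(char_num: int) -> Dict[str, int]:
    """Head sizes derived from a char_num already normalized to the CTC size."""
    return {
        "CTCLabelDecode": char_num,
        "SARLabelDecode": char_num + 2,
        "NRTRLabelDecode": char_num + 3,
    }


def size_rec_heads(config: Dict[str, Any], characters: Optional[List[str]]) -> None:
    """Hypothetical mirror of the rec branch of export(); mutates config in place.

    `characters` stands in for post_process_class.character. Pass None for a
    post-processor without a character list; the rec branch is then skipped.
    """
    if characters is None:
        return
    arch = config["Architecture"]
    pp_name = config["PostProcess"]["name"]
    char_num = len(characters)
    if arch["algorithm"] == "Distillation":
        for sub in arch["Models"].values():
            if sub["Head"]["name"] == "MultiHead":
                # As in the listing, char_num is reassigned inside the loop.
                if pp_name == "DistillationSARLabelDecode":
                    char_num -= 2
                if pp_name == "DistillationNRTRLabelDecode":
                    char_num -= 3
                sub["Head"]["out_channels_list"] = multihead_channels(char_num)
            else:
                sub["Head"]["out_channels"] = char_num
            # Only the final tensor is kept for the exported inference model.
            sub["return_all_feats"] = False
    elif arch["Head"]["name"] == "MultiHead":
        if pp_name == "SARLabelDecode":
            char_num -= 2
        if pp_name == "NRTRLabelDecode":
            char_num -= 3
        arch["Head"]["out_channels_list"] = multihead_channels(char_num)
    else:
        # Base rec model: a single head sized to the character list.
        arch["Head"]["out_channels"] = char_num


if __name__ == "__main__":
    # Toy config, assumed for illustration only.
    cfg = {
        "Architecture": {"algorithm": "SVTR_LCNet", "Head": {"name": "MultiHead"}},
        "PostProcess": {"name": "SARLabelDecode"},
    }
    size_rec_heads(cfg, characters=list("abcdef"))
    print(cfg["Architecture"]["Head"]["out_channels_list"])
    # {'CTCLabelDecode': 4, 'SARLabelDecode': 6, 'NRTRLabelDecode': 7}
```

With the toy config, a 6-entry character list under SARLabelDecode gives head sizes 4, 6 and 7, so the SAR head ends up with exactly the original list length.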
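
The sr and LaTeXOCR steps at the end of the listing do not touch head sizes; they only set inference/export flags on sub-configs. A minimal dict-only sketch of just those two steps, assuming the same nested-dict config layout (set_infer_flags is a hypothetical name, not a PaddleOCR function):

```python
from typing import Any, Dict


def set_infer_flags(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical mirror of the sr / LaTeXOCR steps of export(); mutates and returns config."""
    arch = config["Architecture"]
    # Super-resolution models: set the Transform's infer_mode flag.
    if arch["model_type"] == "sr":
        arch["Transform"]["infer_mode"] = True
    # LaTeXOCR: flag the backbone for prediction/export and the head for export.
    if arch.get("algorithm") in ["LaTeXOCR"]:
        arch["Backbone"]["is_predict"] = True
        arch["Backbone"]["is_export"] = True
        arch["Head"]["is_export"] = True
    return config


if __name__ == "__main__":
    # Toy config, assumed for illustration only.
    cfg = {"Architecture": {"model_type": "rec", "algorithm": "LaTeXOCR", "Backbone": {}, "Head": {}}}
    set_infer_flags(cfg)
    print(cfg["Architecture"]["Backbone"])  # {'is_predict': True, 'is_export': True}
```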

Callers (2)

train · Function · 0.90
main · Function · 0.90

Calls (9)

get_logger · Function · 0.90
build_post_process · Function · 0.90
build_model · Function · 0.90
load_model · Function · 0.90
convert_bn · Function · 0.85
dump_infer_config · Function · 0.85
export_single_model · Function · 0.85
eval · Method · 0.80
get · Method · 0.65

Tested by

no test coverage detected
