()
| 33 | |
| 34 | |
| 35 | def main(): |
| 36 | global_config = config["Global"] |
| 37 | # build dataloader |
| 38 | set_signal_handlers() |
| 39 | valid_dataloader = build_dataloader(config, "Eval", device, logger) |
| 40 | |
| 41 | # build post process |
| 42 | post_process_class = build_post_process(config["PostProcess"], global_config) |
| 43 | |
| 44 | # build model |
| 45 | # for rec algorithm |
| 46 | if hasattr(post_process_class, "character"): |
| 47 | char_num = len(getattr(post_process_class, "character")) |
| 48 | if config["Architecture"]["algorithm"] in [ |
| 49 | "Distillation", |
| 50 | ]: # distillation model |
| 51 | for key in config["Architecture"]["Models"]: |
| 52 | if ( |
| 53 | config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead" |
| 54 | ): # for multi head |
| 55 | out_channels_list = {} |
| 56 | if config["PostProcess"]["name"] == "DistillationSARLabelDecode": |
| 57 | char_num = char_num - 2 |
| 58 | if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode": |
| 59 | char_num = char_num - 3 |
| 60 | out_channels_list["CTCLabelDecode"] = char_num |
| 61 | out_channels_list["SARLabelDecode"] = char_num + 2 |
| 62 | out_channels_list["NRTRLabelDecode"] = char_num + 3 |
| 63 | config["Architecture"]["Models"][key]["Head"][ |
| 64 | "out_channels_list" |
| 65 | ] = out_channels_list |
| 66 | else: |
| 67 | config["Architecture"]["Models"][key]["Head"][ |
| 68 | "out_channels" |
| 69 | ] = char_num |
| 70 | elif config["Architecture"]["Head"]["name"] == "MultiHead": # for multi head |
| 71 | out_channels_list = {} |
| 72 | if config["PostProcess"]["name"] == "SARLabelDecode": |
| 73 | char_num = char_num - 2 |
| 74 | if config["PostProcess"]["name"] == "NRTRLabelDecode": |
| 75 | char_num = char_num - 3 |
| 76 | out_channels_list["CTCLabelDecode"] = char_num |
| 77 | out_channels_list["SARLabelDecode"] = char_num + 2 |
| 78 | out_channels_list["NRTRLabelDecode"] = char_num + 3 |
| 79 | config["Architecture"]["Head"]["out_channels_list"] = out_channels_list |
| 80 | else: # base rec model |
| 81 | config["Architecture"]["Head"]["out_channels"] = char_num |
| 82 | |
| 83 | model = build_model(config["Architecture"]) |
| 84 | extra_input_models = [ |
| 85 | "SRN", |
| 86 | "NRTR", |
| 87 | "SAR", |
| 88 | "SEED", |
| 89 | "SVTR", |
| 90 | "SVTR_LCNet", |
| 91 | "VisionLAN", |
| 92 | "RobustScanner", |
no test coverage detected
searching dependent graphs…