(config, device, logger, vdl_writer)
| 87 | |
| 88 | |
| 89 | def main(config, device, logger, vdl_writer): |
| 90 | # init dist environment |
| 91 | if config["Global"]["distributed"]: |
| 92 | dist.init_parallel_env() |
| 93 | |
| 94 | global_config = config["Global"] |
| 95 | |
| 96 | # build dataloader |
| 97 | set_signal_handlers() |
| 98 | train_dataloader = build_dataloader(config, "Train", device, logger) |
| 99 | if config["Eval"]: |
| 100 | valid_dataloader = build_dataloader(config, "Eval", device, logger) |
| 101 | else: |
| 102 | valid_dataloader = None |
| 103 | |
| 104 | # build post process |
| 105 | post_process_class = build_post_process(config["PostProcess"], global_config) |
| 106 | |
| 107 | # build model |
| 108 | # for rec algorithm |
| 109 | if hasattr(post_process_class, "character"): |
| 110 | char_num = len(getattr(post_process_class, "character")) |
| 111 | if config["Architecture"]["algorithm"] in [ |
| 112 | "Distillation", |
| 113 | ]: # distillation model |
| 114 | for key in config["Architecture"]["Models"]: |
| 115 | if ( |
| 116 | config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead" |
| 117 | ): # for multi head |
| 118 | if config["PostProcess"]["name"] == "DistillationSARLabelDecode": |
| 119 | char_num = char_num - 2 |
| 120 | # update SARLoss params |
| 121 | assert ( |
| 122 | list(config["Loss"]["loss_config_list"][-1].keys())[0] |
| 123 | == "DistillationSARLoss" |
| 124 | ) |
| 125 | config["Loss"]["loss_config_list"][-1]["DistillationSARLoss"][ |
| 126 | "ignore_index" |
| 127 | ] = (char_num + 1) |
| 128 | out_channels_list = {} |
| 129 | out_channels_list["CTCLabelDecode"] = char_num |
| 130 | out_channels_list["SARLabelDecode"] = char_num + 2 |
| 131 | config["Architecture"]["Models"][key]["Head"][ |
| 132 | "out_channels_list" |
| 133 | ] = out_channels_list |
| 134 | else: |
| 135 | config["Architecture"]["Models"][key]["Head"][ |
| 136 | "out_channels" |
| 137 | ] = char_num |
| 138 | elif config["Architecture"]["Head"]["name"] == "MultiHead": # for multi head |
| 139 | if config["PostProcess"]["name"] == "SARLabelDecode": |
| 140 | char_num = char_num - 2 |
| 141 | # update SARLoss params |
| 142 | assert list(config["Loss"]["loss_config_list"][1].keys())[0] == "SARLoss" |
| 143 | if config["Loss"]["loss_config_list"][1]["SARLoss"] is None: |
| 144 | config["Loss"]["loss_config_list"][1]["SARLoss"] = { |
| 145 | "ignore_index": char_num + 1 |
| 146 | } |
no test coverage detected
searching dependent graphs…