(config, device, logger, vdl_writer)
| 44 | |
| 45 | |
| 46 | def main(config, device, logger, vdl_writer): |
| 47 | # init dist environment |
| 48 | if config["Global"]["distributed"]: |
| 49 | dist.init_parallel_env() |
| 50 | |
| 51 | global_config = config["Global"] |
| 52 | |
| 53 | # build dataloader |
| 54 | # NOTE: Do NOT pass seed here. The seed parameter in build_dataloader is used |
| 55 | # as epoch number by set_epoch_as_seed (for adaptive shrink_ratio), not as |
| 56 | # random seed. First construction should use epoch=0 (i.e., seed=None). |
| 57 | # The epoch loop in program.train() handles subsequent updates via |
| 58 | # reset_data_lines(seed=epoch). |
| 59 | set_signal_handlers() |
| 60 | train_dataloader = build_dataloader(config, "Train", device, logger) |
| 61 | if len(train_dataloader) == 0: |
| 62 | logger.error( |
| 63 | "No Images in train dataset, please ensure\n" |
| 64 | + "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n" |
| 65 | + "\t2. The annotation file and path in the configuration file are provided normally." |
| 66 | ) |
| 67 | return |
| 68 | |
| 69 | if config["Eval"]: |
| 70 | valid_dataloader = build_dataloader(config, "Eval", device, logger) |
| 71 | else: |
| 72 | valid_dataloader = None |
| 73 | step_pre_epoch = len(train_dataloader) |
| 74 | |
| 75 | # build post process |
| 76 | post_process_class = build_post_process(config["PostProcess"], global_config) |
| 77 | |
| 78 | # build model |
| 79 | # for rec algorithm |
| 80 | if hasattr(post_process_class, "character"): |
| 81 | char_num = len(getattr(post_process_class, "character")) |
| 82 | if config["Architecture"]["algorithm"] in [ |
| 83 | "Distillation", |
| 84 | ]: # distillation model |
| 85 | for key in config["Architecture"]["Models"]: |
| 86 | if ( |
| 87 | config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead" |
| 88 | ): # for multi head |
| 89 | if config["PostProcess"]["name"] == "DistillationSARLabelDecode": |
| 90 | char_num = char_num - 2 |
| 91 | if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode": |
| 92 | char_num = char_num - 3 |
| 93 | out_channels_list = {} |
| 94 | out_channels_list["CTCLabelDecode"] = char_num |
| 95 | # update SARLoss params |
| 96 | if ( |
| 97 | list(config["Loss"]["loss_config_list"][-1].keys())[0] |
| 98 | == "DistillationSARLoss" |
| 99 | ): |
| 100 | config["Loss"]["loss_config_list"][-1]["DistillationSARLoss"][ |
| 101 | "ignore_index" |
| 102 | ] = (char_num + 1) |
| 103 | out_channels_list["SARLabelDecode"] = char_num + 2 |
no test coverage detected
searching dependent graphs…