hub / github.com/PaddlePaddle/PaddleOCR / main

Function main

tools/train.py:46–273
(config, device, logger, vdl_writer)

Source from the content-addressed store, hash-verified

def main(config, device, logger, vdl_writer):
    # init dist environment
    if config["Global"]["distributed"]:
        dist.init_parallel_env()

    global_config = config["Global"]

    # build dataloader
    # NOTE: Do NOT pass seed here. The seed parameter in build_dataloader is used
    # as epoch number by set_epoch_as_seed (for adaptive shrink_ratio), not as
    # random seed. First construction should use epoch=0 (i.e., seed=None).
    # The epoch loop in program.train() handles subsequent updates via
    # reset_data_lines(seed=epoch).
    set_signal_handlers()
    train_dataloader = build_dataloader(config, "Train", device, logger)
    if len(train_dataloader) == 0:
        logger.error(
            "No Images in train dataset, please ensure\n"
            + "\t1. The images num in the train label_file_list should be larger than or equal with batch size.\n"
            + "\t2. The annotation file and path in the configuration file are provided normally."
        )
        return

    if config["Eval"]:
        valid_dataloader = build_dataloader(config, "Eval", device, logger)
    else:
        valid_dataloader = None
    step_pre_epoch = len(train_dataloader)

    # build post process
    post_process_class = build_post_process(config["PostProcess"], global_config)

    # build model
    # for rec algorithm
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in [
            "Distillation",
        ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if (
                    config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
                ):  # for multi head
                    if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
                        char_num = char_num - 2
                    if config["PostProcess"]["name"] == "DistillationNRTRLabelDecode":
                        char_num = char_num - 3
                    out_channels_list = {}
                    out_channels_list["CTCLabelDecode"] = char_num
                    # update SARLoss params
                    if (
                        list(config["Loss"]["loss_config_list"][-1].keys())[0]
                        == "DistillationSARLoss"
                    ):
                        config["Loss"]["loss_config_list"][-1]["DistillationSARLoss"][
                            "ignore_index"
                        ] = (char_num + 1)
                        out_channels_list["SARLabelDecode"] = char_num + 2
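The channel bookkeeping in the MultiHead distillation branch above is easier to follow in isolation. The sketch below restates only that arithmetic as a pure function. `rec_out_channels` and its parameters are hypothetical names introduced here, not part of PaddleOCR, and the sketch assumes the `SARLabelDecode` assignment sits inside the `DistillationSARLoss` check. The character count of 40 in the usage lines is made up for illustration.

```python
def rec_out_channels(char_num, postprocess_name, last_loss_name):
    """Hypothetical helper mirroring the MultiHead branch of the excerpt.

    char_num         -- len(post_process_class.character) in the excerpt
    postprocess_name -- config["PostProcess"]["name"]
    last_loss_name   -- first key of config["Loss"]["loss_config_list"][-1]
    Returns (out_channels_list, loss_updates).
    """
    # The excerpt subtracts 2 (SAR decode) or 3 (NRTR decode) from char_num
    # before using it as the CTC channel count.
    if postprocess_name == "DistillationSARLabelDecode":
        char_num = char_num - 2
    if postprocess_name == "DistillationNRTRLabelDecode":
        char_num = char_num - 3

    out_channels_list = {"CTCLabelDecode": char_num}
    loss_updates = {}

    # When the last loss entry is DistillationSARLoss, the excerpt writes
    # ignore_index into the loss config and adds a SAR head channel count.
    if last_loss_name == "DistillationSARLoss":
        loss_updates["ignore_index"] = char_num + 1
        out_channels_list["SARLabelDecode"] = char_num + 2

    return out_channels_list, loss_updates


# Usage with a made-up character count of 40.
sar_channels, sar_updates = rec_out_channels(
    40, "DistillationSARLabelDecode", "DistillationSARLoss"
)
ctc_channels, ctc_updates = rec_out_channels(
    40, "DistillationCTCLabelDecode", "DistillationCTCLoss"
)
```

In the SAR case the CTC head gets `char_num - 2` channels and the SAR head gets that reduced count plus 2, which is the original count again. The excerpt ends at this point, so whatever follows in the remainder of `main` is not covered by this sketch.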

Callers 1

train.py · File · 0.70

Calls 14

set_signal_handlers · Function · 0.90
build_dataloader · Function · 0.90
build_post_process · Function · 0.90
build_model · Function · 0.90
apply_to_static · Function · 0.90
build_loss · Function · 0.90
build_optimizer · Function · 0.90
build_metric · Function · 0.90
ModelEMA · Class · 0.90
load_model · Function · 0.90
format · Method · 0.80
train · Method · 0.80

Tested by

no test coverage detected
