github.com/PaddlePaddle/PaddleOCR (branch: main)

Function main

deploy/slim/quantization/quant.py:89–221
(config, device, logger, vdl_writer)
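The first statements of main() decide two things from the config alone: whether the distributed environment is initialised, and which dataloaders are built. The sketch below replays only that decision on a plain dict so the config keys involved are visible. It is illustrative, not PaddleOCR code: setup_plan is a hypothetical helper, it does not call Paddle, and the sample values in the checks are made up.

```python
def setup_plan(config):
    """Hypothetical helper mirroring the setup at the top of main().

    Returns (init_dist, modes): whether dist.init_parallel_env() would run,
    and which dataloader modes would be built. Like the excerpt, it indexes
    config["Eval"] directly, so the key must exist; a falsy value such as
    None means no evaluation loader (valid_dataloader stays None).
    """
    init_dist = bool(config["Global"]["distributed"])
    modes = ["Train"]  # the training loader is always built
    if config["Eval"]:
        modes.append("Eval")
    return init_dist, modes
```

Because the excerpt uses config["Eval"] rather than config.get("Eval"), a config that omits the Eval section entirely raises KeyError instead of silently skipping evaluation.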

def main(config, device, logger, vdl_writer):
    # init dist environment
    if config["Global"]["distributed"]:
        dist.init_parallel_env()

    global_config = config["Global"]

    # build dataloader
    set_signal_handlers()
    train_dataloader = build_dataloader(config, "Train", device, logger)
    if config["Eval"]:
        valid_dataloader = build_dataloader(config, "Eval", device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config["PostProcess"], global_config)

    # build model
    # for rec algorithm
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in [
            "Distillation",
        ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if (
                    config["Architecture"]["Models"][key]["Head"]["name"] == "MultiHead"
                ):  # for multi head
                    if config["PostProcess"]["name"] == "DistillationSARLabelDecode":
                        char_num = char_num - 2
                    # update SARLoss params
                    assert (
                        list(config["Loss"]["loss_config_list"][-1].keys())[0]
                        == "DistillationSARLoss"
                    )
                    config["Loss"]["loss_config_list"][-1]["DistillationSARLoss"][
                        "ignore_index"
                    ] = (char_num + 1)
                    out_channels_list = {}
                    out_channels_list["CTCLabelDecode"] = char_num
                    out_channels_list["SARLabelDecode"] = char_num + 2
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels_list"
                    ] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"
                    ] = char_num
        elif config["Architecture"]["Head"]["name"] == "MultiHead":  # for multi head
            if config["PostProcess"]["name"] == "SARLabelDecode":
                char_num = char_num - 2
            # update SARLoss params
            assert list(config["Loss"]["loss_config_list"][1].keys())[0] == "SARLoss"
            if config["Loss"]["loss_config_list"][1]["SARLoss"] is None:
                config["Loss"]["loss_config_list"][1]["SARLoss"] = {
                    "ignore_index": char_num + 1
                }
    # (remainder of main(), source lines 147–221, not shown in this view)

Callers (1)

quant.py (file, 0.70)

Calls (11)

set_signal_handlers (function, 0.90)
build_dataloader (function, 0.90)
build_post_process (function, 0.90)
build_model (function, 0.90)
load_model (function, 0.90)
build_loss (function, 0.90)
build_optimizer (function, 0.90)
build_metric (function, 0.90)
format (method, 0.80)
train (method, 0.80)
get (method, 0.65)

Tested by: no test coverage detected.
