Repository: github.com/PaddlePaddle/PaddleOCR

Function train_distill

test_tipc/supplementary/train.py:199–313
(config, scaler=None)
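train_distill reads only a few keys from `config` directly: `epoch`, `topk`, `TRAIN.batch_size`, `TRAIN.num_workers`, and the optional `quant_train` / `prune_train` flags. The whole dict is also passed to build_model, create_optimizer and load_model, which may read further keys not visible here. The sketch below is a hypothetical example dict that uses those key names with placeholder values, plus a hypothetical helper, `training_mode`, that mirrors the function's quantization/pruning branch. Because the branch tests `is True`, a truthy non-boolean such as `1` or `"true"` leaves both modes off. When both flags are set, `quant_train` takes precedence.

```python
# Hypothetical config: the key names are the ones train_distill reads directly;
# the values are placeholders, not PaddleOCR defaults.
example_config = {
    "epoch": 100,
    "topk": 5,
    "TRAIN": {"batch_size": 64, "num_workers": 4},
    "quant_train": False,  # optional; QAT with PACT runs only when this is exactly True
    "prune_train": False,  # optional; pruning runs only when this is exactly True
}


def training_mode(config):
    """Hypothetical helper mirroring train_distill's branch: 'qat', 'prune' or 'plain'."""
    if "quant_train" in config and config["quant_train"] is True:
        return "qat"
    elif "prune_train" in config and config["prune_train"] is True:
        return "prune"
    return "plain"


if __name__ == "__main__":
    print(training_mode(example_config))                               # plain
    print(training_mode({"quant_train": True, "prune_train": True}))  # qat
    print(training_mode({"prune_train": 1}))                           # plain: 1 is not True
```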


```python
# paddle, time, logger, quant_config and the helper builders/losses are
# imported or defined elsewhere in train.py (not shown in this excerpt).
def train_distill(config, scaler=None):
    EPOCH = config["epoch"]
    topk = config["topk"]

    batch_size = config["TRAIN"]["batch_size"]
    num_workers = config["TRAIN"]["num_workers"]
    train_loader = build_dataloader(
        "train", batch_size=batch_size, num_workers=num_workers
    )

    # build metric
    metric_func = create_metric

    # model = distillmv3_large_x0_5(class_dim=100)
    model = build_model(config)

    # optional quantization-aware training with PACT activation preprocessing,
    # or pruning; plain training otherwise
    if "quant_train" in config and config["quant_train"] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config["prune_train"] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    # build optimizer and learning-rate scheduler
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters()
    )

    # load a checkpoint (if configured) and log its stored metrics
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = "The metric of loaded metric as follows {}".format(
            ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
        )
        logger.info(pre_str)

    model.train()
    # wrap the model for data-parallel (multi-GPU) training
    model = paddle.DataParallel(model)

    # build loss functions for the two sub-models "student" and "student1"
    loss_func_distill = LossDistill(model_name_list=["student", "student1"])
    loss_func_dml = DMLLoss(model_name_pairs=["student", "student1"])
    loss_func_js = KLJSLoss(mode="js")

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            # reorder image axes from NHWC to NCHW
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            # append a trailing axis to the labels
            label = paddle.unsqueeze(label, -1)
            if scaler is not None:
                # mixed-precision forward pass when an AMP loss scaler is supplied
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            # ... remainder of the function (through line 313) not shown
```
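Inside the loop each batch is reshaped before the forward pass. The permutation `[0, 3, 1, 2]` moves the last axis to position 1, which converts an NHWC image batch to the NCHW layout that Paddle convolutions use by default. `unsqueeze(label, -1)` appends a trailing axis to the labels. The NumPy sketch below reproduces both steps on a dummy batch. It assumes, as the permutation suggests, that the loader yields NHWC images and 1-D integer labels. The 32×32 size is borrowed from the `[1, 3, 32, 32]` shape passed to prune_model, the 100 classes come from the commented-out `class_dim=100`, and the batch size of 8 is arbitrary.

```python
import numpy as np

# Dummy batch standing in for one train_loader item (assumed NHWC images, 1-D labels).
batch_size, height, width, channels = 8, 32, 32, 3
img_batch = np.random.rand(batch_size, height, width, channels).astype("float32")
label = np.random.randint(0, 100, size=(batch_size,)).astype("int64")

# Same permutation as paddle.transpose(img_batch, [0, 3, 1, 2]): NHWC -> NCHW.
img_nchw = np.transpose(img_batch, (0, 3, 1, 2))

# Same effect as paddle.unsqueeze(label, -1): shape [N] -> [N, 1].
label_col = np.expand_dims(label, -1)

print(img_nchw.shape)   # (8, 3, 32, 32)
print(label_col.shape)  # (8, 1)
```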

Callers (1)

train.py · File · 0.85

Calls (14)

build_dataloader · Function · 0.90
build_model · Function · 0.90
prune_model · Function · 0.90
create_optimizer · Function · 0.90
load_model · Function · 0.90
LossDistill · Class · 0.90
DMLLoss · Class · 0.90
KLJSLoss · Class · 0.90
format · Method · 0.80
train · Method · 0.80
backward · Method · 0.80
step · Method · 0.80
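train_distill builds three loss objects for the two sub-models `"student"` and `"student1"`. Their bodies are not part of this excerpt, but `KLJSLoss(mode="js")` points to a Jensen-Shannon divergence between two predicted distributions. The sketch below is the textbook definition, JS(P, Q) = ½·KL(P‖M) + ½·KL(Q‖M) with M = (P + Q)/2, written in NumPy. It is not the PaddleOCR class itself, and any epsilon smoothing or reduction that KLJSLoss applies may differ. Unlike KL, JS is symmetric in its two arguments and bounded above by ln 2 (in nats), so neither student acts as the fixed reference. The batch of 4 samples and 100 classes in the demo are hypothetical.

```python
import numpy as np


def softmax(logits, axis=-1):
    """Numerically stable softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


def js_divergence(p, q, axis=-1):
    """Jensen-Shannon divergence (natural log) between distributions p and q along `axis`."""
    m = 0.5 * (p + q)

    def kl_to_mixture(x):
        # sum_i x_i * log(x_i / m_i); terms with x_i == 0 contribute 0.
        # m_i >= x_i / 2 > 0 wherever x_i > 0, so the log argument is always finite.
        positive = x > 0
        safe_x = np.where(positive, x, 1.0)
        safe_m = np.where(positive, m, 1.0)
        return np.sum(np.where(positive, x * np.log(safe_x / safe_m), 0.0), axis=axis)

    return 0.5 * kl_to_mixture(p) + 0.5 * kl_to_mixture(q)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical logits from two peer students: batch of 4 samples, 100 classes.
    logits_student = rng.normal(size=(4, 100))
    logits_student1 = rng.normal(size=(4, 100))
    p, q = softmax(logits_student), softmax(logits_student1)
    print(js_divergence(p, q).mean())  # batch-mean divergence, between 0 and ln 2
```

In a mutual-learning setup a quantity like this would be computed on the two students' outputs for every batch and combined with the supervised loss. How train_distill actually weights and sums its three losses is in the part of the function not shown here.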

Tested by

no test coverage detected
