Repository: github.com/PaddlePaddle/PaddleOCR

Function train

Source: test_tipc/supplementary/train.py, lines 87–196
(config, scaler=None)
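train() takes a nested config dict plus an optional AMP scaler. The sketch below shows the keys that train() indexes directly in its body and mirrors its quantization/pruning branch. The helper names (compression_mode, check_train_config) and all values are hypothetical placeholders, not PaddleOCR defaults. build_model, create_optimizer, load_model and build_loss also receive config and read their own keys, which are not visible here, so a real config needs more than this.

```python
# Hypothetical example config: key names mirror what train() reads directly;
# the values are illustrative placeholders, not PaddleOCR defaults.
example_config = {
    "epoch": 10,
    "topk": 5,
    "TRAIN": {"batch_size": 64, "num_workers": 4},
    "quant_train": False,
    "prune_train": False,
}


def compression_mode(config):
    """Mirror train()'s if/elif chain: quantization takes priority over pruning."""
    # train() uses `is True`, so only the bool True enables a mode;
    # truthy values such as 1 or "True" do not.
    if "quant_train" in config and config["quant_train"] is True:
        return "quant"
    elif "prune_train" in config and config["prune_train"] is True:
        return "prune"
    return "none"


def check_train_config(config):
    """Hypothetical helper: fail early if a key train() indexes directly is missing."""
    for key in ("epoch", "topk", "TRAIN"):
        if key not in config:
            raise KeyError(f"config is missing required key {key!r}")
    for key in ("batch_size", "num_workers"):
        if key not in config["TRAIN"]:
            raise KeyError(f"config['TRAIN'] is missing required key {key!r}")
    return compression_mode(config)


print(check_train_config(example_config))                           # none
print(compression_mode({"quant_train": True, "prune_train": True}))  # quant
print(compression_mode({"quant_train": 1}))                         # none
```

Because the checks use `is True`, a YAML boolean `true` enables a mode, but an integer 1 or a quoted string does not.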


```python
def train(config, scaler=None):
    # build_dataloader, create_metric, build_model, create_optimizer, load_model,
    # QAT, PACT, quant_config, prune_model, build_loss, logger, paddle and time
    # are not defined in this excerpt; they presumably come from module-level
    # imports and definitions elsewhere in train.py.
    EPOCH = config["epoch"]
    topk = config["topk"]

    batch_size = config["TRAIN"]["batch_size"]
    num_workers = config["TRAIN"]["num_workers"]
    train_loader = build_dataloader(
        "train", batch_size=batch_size, num_workers=num_workers
    )

    # build metric (stores the factory function itself; it is not called here)
    metric_func = create_metric

    # build model
    # model = MobileNetV3_large_x0_5(class_dim=100)
    model = build_model(config)

    # build optimizer and learning-rate scheduler
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters()
    )

    # load model; the result appears to be a dict of previously saved metrics
    # (empty when nothing was loaded)
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = "The metric of loaded metric as follows {}".format(
            ", ".join(["{}: {}".format(k, v) for k, v in pre_best_model_dict.items()])
        )
        logger.info(pre_str)

    # about slim prune and quant: quantization-aware training takes priority
    if "quant_train" in config and config["quant_train"] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config["prune_train"] is True:
        # example input shape [1, 3, 32, 32]; 0.1 is presumably the pruning ratio
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    # distribution: switch to training mode, then wrap for data-parallel training
    model.train()
    model = paddle.DataParallel(model)
    # build loss function
    loss_func = build_loss(config)

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            # reorder images from NHWC to NCHW
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            # add a trailing axis to the labels, e.g. [N] -> [N, 1]
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                # mixed-precision forward pass
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
                # ... listing truncated here; the function continues to line 196
```
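The loop reorders each image batch with paddle.transpose(img_batch, [0, 3, 1, 2]) and adds a trailing label axis with paddle.unsqueeze(label, -1). The transpose suggests the loader yields channel-last (NHWC) images, while Paddle's convolution layers default to NCHW. The [N, 1] label shape is likely what the loss or the top-k metric expects, although that code is not shown. The framework-free sketch below reproduces both shape changes on nested Python lists. The helper names are hypothetical and are not part of PaddleOCR.

```python
def nhwc_to_nchw(batch):
    """Reorder a nested-list batch from [N, H, W, C] to [N, C, H, W].

    Same effect as transposing with perm [0, 3, 1, 2]: output axis i takes
    input axis perm[i], so out[n][c][h][w] == batch[n][h][w][c].
    """
    n_dim = len(batch)
    h_dim = len(batch[0])
    w_dim = len(batch[0][0])
    c_dim = len(batch[0][0][0])
    return [
        [
            [[batch[n][h][w][c] for w in range(w_dim)] for h in range(h_dim)]
            for c in range(c_dim)
        ]
        for n in range(n_dim)
    ]


def unsqueeze_last(labels):
    """Turn a flat label list of shape [N] into [N, 1] (unsqueeze at axis -1)."""
    return [[y] for y in labels]


# One 2x2 image with 3 channels: shape [1, 2, 2, 3].
batch = [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]]
print(nhwc_to_nchw(batch))
# [[[[1, 4], [7, 10]], [[2, 5], [8, 11]], [[3, 6], [9, 12]]]]
print(unsqueeze_last([3, 7]))  # [[3], [7]]
```

After the reorder, each channel becomes a contiguous 2x2 plane, which is the layout an NCHW convolution consumes.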

Callers (1)

train.py · File · 0.70

Calls (12)

build_dataloader · Function · 0.90
build_model · Function · 0.90
create_optimizer · Function · 0.90
load_model · Function · 0.90
prune_model · Function · 0.90
build_loss · Function · 0.90
format · Method · 0.80
train · Method · 0.80
backward · Method · 0.80
step · Method · 0.80
eval · Function · 0.70
save_model · Function · 0.70
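The call list shows that train() also calls backward and step, in lines the listing does not reach. When a scaler is passed, the forward pass runs under paddle.amp.auto_cast(). In Paddle, such a scaler is typically a paddle.amp.GradScaler, which applies dynamic loss scaling. How the elided lines actually use it is not shown here. The class below is a hypothetical, framework-free sketch of the dynamic loss-scaling technique on plain floats. Its class and parameter names are illustrative, and it is neither Paddle's implementation nor a reconstruction of train().

```python
import math


class DynamicLossScaler:
    """Hypothetical sketch of dynamic loss scaling (not Paddle's API).

    - Multiply the loss by `scale` before backward so small fp16 gradients
      do not underflow.
    - Before the optimizer step, check the scaled gradients: if any is inf or
      NaN, skip the step and shrink the scale.
    - Otherwise unscale them, and grow the scale after `growth_interval`
      consecutive good steps.
    """

    def __init__(self, init_scale=2.0 ** 15, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=1000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0

    def scale_loss(self, loss):
        return loss * self.scale

    def unscale_and_check(self, scaled_grads):
        """Return (unscaled grads, True), or (None, False) if any grad overflowed."""
        if any(math.isinf(g) or math.isnan(g) for g in scaled_grads):
            self.scale *= self.backoff_factor  # overflow: skip step, shrink scale
            self._good_steps = 0
            return None, False
        # Unscale with the scale that was used for this step, before any growth.
        grads = [g / self.scale for g in scaled_grads]
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            self.scale *= self.growth_factor
            self._good_steps = 0
        return grads, True


scaler = DynamicLossScaler(init_scale=8.0, growth_interval=2)
print(scaler.scale_loss(0.25))                    # 2.0
print(scaler.unscale_and_check([4.0, -8.0]))      # ([0.5, -1.0], True)
print(scaler.unscale_and_check([float("inf")]))   # (None, False)
print(scaler.scale)                               # 4.0
```

Skipping the optimizer step on overflow, rather than clipping, keeps corrupted gradients out of the weights. Growing the scale again after a run of clean steps recovers precision once the overflow has passed.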

Tested by

no test coverage detected
