hub / github.com/PaddlePaddle/PaddleOCR / train

Function train

tools/program.py:200–706
(
    config,
    train_dataloader,
    valid_dataloader,
    device,
    model,
    loss_class,
    optimizer,
    lr_scheduler,
    post_process_class,
    eval_class,
    pre_best_model_dict,
    logger,
    step_pre_epoch,
    log_writer=None,
    scaler=None,
    amp_level="O2",
    amp_custom_black_list=[],
    amp_custom_white_list=[],
    amp_dtype="float16",
    wd_scheduler=None,
    ema=None,
)

Source from the content-addressed store, hash-verified

def train(
    config,
    train_dataloader,
    valid_dataloader,
    device,
    model,
    loss_class,
    optimizer,
    lr_scheduler,
    post_process_class,
    eval_class,
    pre_best_model_dict,
    logger,
    step_pre_epoch,
    log_writer=None,
    scaler=None,
    amp_level="O2",
    amp_custom_black_list=[],
    amp_custom_white_list=[],
    amp_dtype="float16",
    wd_scheduler=None,
    ema=None,
):
    cal_metric_during_train = config["Global"].get("cal_metric_during_train", False)
    calc_epoch_interval = config["Global"].get("calc_epoch_interval", 1)
    log_smooth_window = config["Global"]["log_smooth_window"]
    epoch_num = config["Global"]["epoch_num"]
    print_batch_step = config["Global"]["print_batch_step"]
    eval_batch_step = config["Global"]["eval_batch_step"]
    eval_batch_epoch = config["Global"].get("eval_batch_epoch", None)
    profiler_options = config["profiler_options"]
    print_mem_info = config["Global"].get("print_mem_info", True)
    uniform_output_enabled = config["Global"].get("uniform_output_enabled", False)

    global_step = 0
    if "global_step" in pre_best_model_dict:
        global_step = pre_best_model_dict["global_step"]
    start_eval_step = 0
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0] if not eval_batch_epoch else 0
        eval_batch_step = (
            eval_batch_step[1]
            if not eval_batch_epoch
            else step_pre_epoch * eval_batch_epoch
        )
        if len(valid_dataloader) == 0:
            logger.info(
                "No Images in eval dataset, evaluation during training "
                "will be disabled"
            )
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, "
            "an evaluation is run every {} iterations".format(
                start_eval_step, eval_batch_step
            )
        )
    save_epoch_step = config["Global"]["save_epoch_step"]
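
The evaluation-schedule logic in the excerpt above can be easier to follow in isolation. The sketch below restates it as a standalone function. The name `resolve_eval_schedule` and its parameters `global_cfg` and `num_valid_batches` are hypothetical and not part of PaddleOCR; `global_cfg` stands in for `config["Global"]` and `num_valid_batches` for `len(valid_dataloader)`. It is an illustration under those assumptions, not the project's implementation.

```python
def resolve_eval_schedule(global_cfg, step_pre_epoch, num_valid_batches):
    """Hypothetical helper mirroring the excerpt's evaluation-schedule logic.

    Returns (start_eval_step, eval_batch_step).
    """
    eval_batch_step = global_cfg["eval_batch_step"]
    eval_batch_epoch = global_cfg.get("eval_batch_epoch", None)

    start_eval_step = 0
    if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2:
        if eval_batch_epoch:
            # An epoch-based interval overrides the [start, interval] pair:
            # evaluation starts at step 0 and runs every N epochs' worth of steps.
            start_eval_step = 0
            eval_batch_step = step_pre_epoch * eval_batch_epoch
        else:
            # Step-based pair: [first step to evaluate after, interval in steps].
            start_eval_step, eval_batch_step = eval_batch_step[0], eval_batch_step[1]
        if num_valid_batches == 0:
            # An unreachable start step effectively disables evaluation.
            start_eval_step = 1e111
    return start_eval_step, eval_batch_step


# Step-based schedule: start after step 0, evaluate every 2000 steps.
print(resolve_eval_schedule({"eval_batch_step": [0, 2000]}, 500, 10))  # (0, 2000)

# Epoch-based schedule: 500 steps per epoch, evaluate every 2 epochs.
print(resolve_eval_schedule(
    {"eval_batch_step": [0, 2000], "eval_batch_epoch": 2}, 500, 10
))  # (0, 1000)
```

As in the excerpt, the empty-dataset check sits inside the list branch, so a scalar `eval_batch_step` passes through unchanged and `start_eval_step` stays at 0.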

Callers

nothing calls this directly

Calls 15

update (Method) · 0.95
update (Method) · 0.95
get (Method) · 0.95
log (Method) · 0.95
TrainingStats (Class) · 0.90
AverageMeter (Class) · 0.90
export (Function) · 0.90
save_model (Function) · 0.90
to_float32 (Function) · 0.85
format (Method) · 0.80
train (Method) · 0.80
reset_data_lines (Method) · 0.80

Tested by

no test coverage detected
