(
config,
train_dataloader,
valid_dataloader,
device,
model,
loss_class,
optimizer,
lr_scheduler,
post_process_class,
eval_class,
pre_best_model_dict,
logger,
step_pre_epoch,
log_writer=None,
scaler=None,
amp_level="O2",
amp_custom_black_list=[],
amp_custom_white_list=[],
amp_dtype="float16",
wd_scheduler=None,
ema=None,
)
| 198 | |
| 199 | |
| 200 | def train( |
| 201 | config, |
| 202 | train_dataloader, |
| 203 | valid_dataloader, |
| 204 | device, |
| 205 | model, |
| 206 | loss_class, |
| 207 | optimizer, |
| 208 | lr_scheduler, |
| 209 | post_process_class, |
| 210 | eval_class, |
| 211 | pre_best_model_dict, |
| 212 | logger, |
| 213 | step_pre_epoch, |
| 214 | log_writer=None, |
| 215 | scaler=None, |
| 216 | amp_level="O2", |
| 217 | amp_custom_black_list=[], |
| 218 | amp_custom_white_list=[], |
| 219 | amp_dtype="float16", |
| 220 | wd_scheduler=None, |
| 221 | ema=None, |
| 222 | ): |
| 223 | cal_metric_during_train = config["Global"].get("cal_metric_during_train", False) |
| 224 | calc_epoch_interval = config["Global"].get("calc_epoch_interval", 1) |
| 225 | log_smooth_window = config["Global"]["log_smooth_window"] |
| 226 | epoch_num = config["Global"]["epoch_num"] |
| 227 | print_batch_step = config["Global"]["print_batch_step"] |
| 228 | eval_batch_step = config["Global"]["eval_batch_step"] |
| 229 | eval_batch_epoch = config["Global"].get("eval_batch_epoch", None) |
| 230 | profiler_options = config["profiler_options"] |
| 231 | print_mem_info = config["Global"].get("print_mem_info", True) |
| 232 | uniform_output_enabled = config["Global"].get("uniform_output_enabled", False) |
| 233 | |
| 234 | global_step = 0 |
| 235 | if "global_step" in pre_best_model_dict: |
| 236 | global_step = pre_best_model_dict["global_step"] |
| 237 | start_eval_step = 0 |
| 238 | if isinstance(eval_batch_step, list) and len(eval_batch_step) >= 2: |
| 239 | start_eval_step = eval_batch_step[0] if not eval_batch_epoch else 0 |
| 240 | eval_batch_step = ( |
| 241 | eval_batch_step[1] |
| 242 | if not eval_batch_epoch |
| 243 | else step_pre_epoch * eval_batch_epoch |
| 244 | ) |
| 245 | if len(valid_dataloader) == 0: |
| 246 | logger.info( |
| 247 | "No Images in eval dataset, evaluation during training " |
| 248 | "will be disabled" |
| 249 | ) |
| 250 | start_eval_step = 1e111 |
| 251 | logger.info( |
| 252 | "During the training process, after the {}th iteration, " |
| 253 | "an evaluation is run every {} iterations".format( |
| 254 | start_eval_step, eval_batch_step |
| 255 | ) |
| 256 | ) |
| 257 | save_epoch_step = config["Global"]["save_epoch_step"] |
nothing calls this directly
no test coverage detected
searching dependent graphs…