(args, model, optimizer, start_epoch)
| 417 | return start_epoch |
| 418 | |
def run(args, model, optimizer, start_epoch):
    """Drive the outer training loop from ``start_epoch`` onward.

    For every epoch index in ``[start_epoch, num_epochs)`` this calls
    ``train``, writes a checkpoint through ``checkpoint_model``, logs the
    wall-clock time the iteration took, and then asks ``is_time_to_exit``
    whether the loop should stop early.

    Args:
        args: Run-wide settings object. This function reads ``args.config``
            (indexed as ``config["training"]["num_epochs"]``),
            ``args.logger`` and ``args.saved_model_path``; it is also
            forwarded unchanged to ``train`` and ``is_time_to_exit``.
        model: Passed through to ``train`` and ``checkpoint_model``.
        optimizer: Passed through to ``train``.
        start_epoch: Zero-based index of the first epoch to run.

    Returns:
        None.
    """
    # ``global_step``, ``global_data_samples`` and
    # ``last_global_step_from_restore`` are module-level names that are only
    # read in this function, so no ``global`` declaration is required.
    log = args.logger
    total_epochs = args.config["training"]["num_epochs"]

    for epoch_idx in range(start_epoch, total_epochs):
        epoch_number = epoch_idx + 1  # 1-based number used in messages/ids

        log.info(f"Training Epoch: {epoch_number}")
        started_at = time.time()
        train(args, epoch_idx, model, optimizer)

        log.info(f"Saving a checkpointing of the model for epoch: {epoch_number}")
        checkpoint_model(
            PATH=args.saved_model_path,
            ckpt_id=f'epoch{epoch_number}_step{global_step}',
            model=model,
            epoch=epoch_number,
            last_global_step=global_step,
            last_global_data_samples=global_data_samples,
        )

        # The measured interval covers both the training call and the
        # checkpoint write.
        elapsed = time.time() - started_at
        log.info(f"Time for shard {epoch_number}: {elapsed} seconds")

        # Only the steps taken beyond ``last_global_step_from_restore`` are
        # handed to the exit check.
        steps_this_run = global_step - last_global_step_from_restore
        if not is_time_to_exit(args=args, global_steps=steps_this_run):
            continue

        print(f'Warning: Early training termination due to max steps limit, epoch={epoch_number}, global_step={steps_this_run}')
        break
| 447 | |
| 448 | |
| 449 | def main(): |
no test coverage detected