(args, model)
| 378 | return model, optimizer |
| 379 | |
def load_checkpoint(args, model):
    """Restore training state from a saved checkpoint and return its epoch.

    Calls ``load_training_checkpoint`` with the path and id taken from
    ``args`` and stores the restored step and sample counters in the
    module-level ``global_step`` and ``global_data_samples``. When
    ``args.rewarmup`` is set, the restored step is also recorded in
    ``last_global_step_from_restore``.

    The learning rate implied by the schedule at the restored step is
    computed for logging only; it is not applied to anything here.

    When ``args.finetune`` is falsy and ``args.max_seq_length`` equals 512,
    ``pretrain_validation`` is run once on the restored model before
    returning.

    Args:
        args: Run settings. The attributes read here are ``config``,
            ``logger``, ``load_training_checkpoint``, ``load_checkpoint_id``,
            ``rewarmup``, ``finetune``, ``max_seq_length`` and
            ``train_micro_batch_size_per_gpu``.
        model: Model object handed to ``load_training_checkpoint`` and
            ``pretrain_validation``.

    Returns:
        The ``start_epoch`` value reported by ``load_training_checkpoint``.
    """
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    config = args.config
    log = args.logger

    log.info(
        f"Restoring previous training checkpoint from PATH={args.load_training_checkpoint}, CKPT_ID={args.load_checkpoint_id}")
    restored_state = load_training_checkpoint(
        args=args,
        model=model,
        PATH=args.load_training_checkpoint,
        ckpt_id=args.load_checkpoint_id)
    start_epoch, global_step, global_data_samples = restored_state
    log.info(
        f"The model is loaded from last checkpoint at epoch {start_epoch} when the global steps were at {global_step} and global data samples at {global_data_samples}")

    if args.rewarmup:
        log.info(
            f"Rewarmup learning rate with last_global_step_from_restore = {global_step}")
        last_global_step_from_restore = global_step

    # Report the scheduled learning rate at the restored step.
    training_cfg = config["training"]
    base_lr = training_cfg["learning_rate"]
    schedule_factor = warmup_linear_decay_exp(
        global_step,
        training_cfg["decay_rate"],
        training_cfg["decay_step"],
        training_cfg["total_training_steps"],
        training_cfg["warmup_proportion"])
    lr_this_step = base_lr * schedule_factor
    log.info(f"Restart training with lr = {lr_this_step}")

    # Validation of the restored checkpoint only applies to pretraining
    # runs at sequence length 512; everything else returns right away.
    if args.finetune or args.max_seq_length != 512:
        return start_epoch

    log.info(f"Validation Loss of Checkpoint {start_epoch} before pretraining")
    log.info(f"TRAIN MICRO BATCH SIZE PER GPU: {args.train_micro_batch_size_per_gpu}")
    # Validate against the previous epoch's index unless start_epoch is
    # not positive, in which case it is passed through unchanged.
    if start_epoch > 0:
        validation_index = start_epoch - 1
    else:
        validation_index = start_epoch
    pretrain_validation(args, validation_index, model)

    return start_epoch
| 418 | |
| 419 | def run(args, model, optimizer, start_epoch): |
| 420 | global global_step |
no test coverage detected