github.com/deepspeedai/DeepSpeedExamples

Function load_checkpoint

bing_bert/deepspeed_train.py:380–417
(args, model)

    return model, optimizer

def load_checkpoint(args, model):
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    config = args.config
    logger = args.logger

    # Restore the model from the checkpoint and recover the epoch, step, and sample counters.
    logger.info(
        f"Restoring previous training checkpoint from PATH={args.load_training_checkpoint}, CKPT_ID={args.load_checkpoint_id}")
    start_epoch, global_step, global_data_samples = load_training_checkpoint(
        args=args,
        model=model,
        PATH=args.load_training_checkpoint,
        ckpt_id=args.load_checkpoint_id)
    logger.info(
        f"The model is loaded from last checkpoint at epoch {start_epoch} when the global steps were at {global_step} and global data samples at {global_data_samples}")

    # If args.rewarmup is set, record the restored step so the LR warmup can restart from it.
    if args.rewarmup:
        logger.info(
            f"Rewarmup learning rate with last_global_step_from_restore = {global_step}")
        last_global_step_from_restore = global_step

    # Log the learning rate the schedule gives at the restored global_step.
    lr_this_step = config["training"]["learning_rate"] * warmup_linear_decay_exp(
        global_step,
        config["training"]["decay_rate"],
        config["training"]["decay_step"],
        config["training"]["total_training_steps"],
        config["training"]["warmup_proportion"])
    logger.info(f"Restart training with lr = {lr_this_step}")

    # Run validation on the restored checkpoint before training
    # (pretraining runs with max_seq_length == 512 only).
    if not args.finetune and args.max_seq_length == 512:
        logger.info(f"Validation Loss of Checkpoint {start_epoch} before pretraining")
        logger.info(f"TRAIN MICRO BATCH SIZE PER GPU: {args.train_micro_batch_size_per_gpu}")
        # pretrain_validation receives start_epoch - 1, or start_epoch itself when start_epoch <= 0.
        index = start_epoch - 1 if start_epoch > 0 else start_epoch
        pretrain_validation(args, index, model)

    return start_epoch

def run(args, model, optimizer, start_epoch):
    global global_step

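The control flow can be exercised without DeepSpeed by stubbing the two helpers. In this minimal sketch, fake_load_training_checkpoint, fake_pretrain_validation, load_checkpoint_sketch, the checkpoint path, the checkpoint ID, and the returned counters are all hypothetical placeholders; logging and the learning-rate computation are left out. It shows the three decisions the function makes: rebinding the module-level counters, recording the restore step when args.rewarmup is set, and validating at index start_epoch - 1 only for non-finetuning runs with max_seq_length == 512.

```python
from types import SimpleNamespace

# Module-level training state, mirroring the globals used by deepspeed_train.py.
global_step = 0
global_data_samples = 0
last_global_step_from_restore = 0

validated_indices = []  # epoch indices passed to the validation stub


def fake_load_training_checkpoint(args, model, PATH, ckpt_id):
    # Hypothetical stub returning (start_epoch, global_step, global_data_samples).
    return 3, 12000, 12000 * 256


def fake_pretrain_validation(args, index, model):
    # Hypothetical stub: record the index instead of running validation.
    validated_indices.append(index)


def load_checkpoint_sketch(args, model):
    global global_step, global_data_samples, last_global_step_from_restore
    start_epoch, global_step, global_data_samples = fake_load_training_checkpoint(
        args=args,
        model=model,
        PATH=args.load_training_checkpoint,
        ckpt_id=args.load_checkpoint_id)
    if args.rewarmup:
        # Remember the step the run was restored at.
        last_global_step_from_restore = global_step
    # Validate only for pretraining runs at sequence length 512.
    if not args.finetune and args.max_seq_length == 512:
        index = start_epoch - 1 if start_epoch > 0 else start_epoch
        fake_pretrain_validation(args, index, model)
    return start_epoch


args = SimpleNamespace(
    load_training_checkpoint="/tmp/placeholder_ckpt_dir",
    load_checkpoint_id="placeholder_id",
    rewarmup=True,
    finetune=False,
    max_seq_length=512)
print(load_checkpoint_sketch(args, model=None))    # 3
print(global_step, last_global_step_from_restore)  # 12000 12000
print(validated_indices)                           # [2]
```

With start_epoch = 3 the stub validation receives index 2; a run with finetune=True or any other max_seq_length skips validation entirely.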
Callers 1

main · Function · 0.70

Calls 4

warmup_linear_decay_exp · Function · 0.90
load_training_checkpoint · Function · 0.85
pretrain_validation · Function · 0.85
info · Method · 0.80
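
The restart learning rate logged by load_checkpoint is config["training"]["learning_rate"] scaled by warmup_linear_decay_exp(global_step, decay_rate, decay_step, total_training_steps, warmup_proportion). That helper's body is not shown on this page. The sketch below assumes, going only by its name and arguments, a linear warmup over the first warmup_proportion × total_training_steps steps followed by exponential decay by decay_rate per decay_step steps; warmup_linear_decay_exp_sketch is a hypothetical stand-in, not the repository's implementation, and the config values are placeholders.

```python
def warmup_linear_decay_exp_sketch(global_step, decay_rate, decay_steps, total_steps, warmup):
    """Hypothetical LR multiplier: linear warmup, then exponential decay."""
    warmup_end = warmup * total_steps
    if global_step < warmup_end:
        # Ramp linearly from 0 to 1 over the warmup window.
        return global_step / warmup_end
    # After warmup, decay by a factor of decay_rate every decay_steps steps.
    return decay_rate ** ((global_step - warmup_end) / decay_steps)


# Placeholder values for the config["training"] keys that load_checkpoint reads.
config = {
    "training": {
        "learning_rate": 4e-4,
        "decay_rate": 0.9,
        "decay_step": 250,
        "total_training_steps": 10000,
        "warmup_proportion": 0.1,
    }
}

t = config["training"]
for step in (0, 500, 1000, 1250):
    lr = t["learning_rate"] * warmup_linear_decay_exp_sketch(
        step, t["decay_rate"], t["decay_step"], t["total_training_steps"], t["warmup_proportion"])
    print(step, lr)
```

In load_checkpoint, lr_this_step is evaluated at the restored global_step itself; last_global_step_from_restore is only recorded there, not used in that computation.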

Tested by

no test coverage detected