hub / github.com/InternLM/InternLM / try_resume_training

Method try_resume_training

internlm/utils/model_checkpoint.py:533–561

Attempt to restore the training state of the last checkpoint. Args: lr_scheduler (_LRScheduler): lr_scheduler object. optimizer (Optimizer): optimizer object. lr (float): learning rate. train_state (dict): training states. train_dl (DataLoader): training dataloader object.

(self, lr_scheduler, optimizer, lr, train_state, train_dl)

Source from the content-addressed store, hash-verified

    def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl):
        """Attempt to restore the training state of the last ckpt.

        Args:
            lr_scheduler (_LRScheduler): lr_scheduler object.
            optimizer (Optimizer): optimizer object.
            lr (float): learning rate.
            train_state (dict): training states.
            train_dl (DataLoader): training dataloader object
        """
        if self.load_ckpt_folder is not None:
            # load optimizer states.
            if self.load_optimizer:
                load_optimizer_checkpoint(self.load_ckpt_folder, optimizer)
            # load lr scheduler states.
            load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
            # load training states.
            load_context(self.load_ckpt_folder, train_dl, train_state)
            # load dataloader sampler states.
            if hasattr(train_state, "batch_sampler") and not isinstance(
                train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
            ):
                load_sampler(self.load_ckpt_folder, train_dl.batch_sampler)
            if hasattr(train_state, "data_state_dict"):
                train_dl.dataset.load_state_dict(
                    llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder
                )
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
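
The order and conditions under which the method calls each loader can be traced with a small sketch. This is a simplified illustration, not the InternLM implementation. `resume_steps` and `StockBatchSampler` are hypothetical names introduced here; the latter stands in for `torch.utils.data.sampler.BatchSampler` so the example runs without PyTorch. The sketch also assumes the dataset-state check sits beside the sampler check rather than inside it, and it only records which loaders would be called instead of touching any checkpoint files.

```python
from types import SimpleNamespace


class StockBatchSampler:
    """Hypothetical stand-in for torch.utils.data.sampler.BatchSampler."""


def resume_steps(load_ckpt_folder, load_optimizer, train_state):
    """Return, in order, the names of the loaders the method would call."""
    steps = []
    if load_ckpt_folder is None:
        # No checkpoint folder configured: nothing is restored.
        return steps
    if load_optimizer:
        steps.append("load_optimizer_checkpoint")
    # Scheduler and training context are always restored when a folder is set.
    steps.append("load_scheduler")
    steps.append("load_context")
    # Sampler state is restored only when train_state carries a batch sampler
    # that is NOT an instance of the stock BatchSampler class.
    if hasattr(train_state, "batch_sampler") and not isinstance(
        train_state.batch_sampler, StockBatchSampler
    ):
        steps.append("load_sampler")
    # Dataset state is restored only when train_state has data_state_dict.
    if hasattr(train_state, "data_state_dict"):
        steps.append("dataset.load_state_dict")
    return steps


# A train_state with a custom (non-stock) batch sampler and no dataset state.
state = SimpleNamespace(batch_sampler=object())
print(resume_steps("ckpt_dir", True, state))
```

In this trace, a missing checkpoint folder skips every loader, while the optimizer and lr_scheduler would still be attached to the manager at the end of the real method.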

Callers 1

main · Function · 0.95

Calls 6

llm_load · Function · 0.90
load_scheduler · Function · 0.85
load_context · Function · 0.85
load_sampler · Function · 0.85
load_state_dict · Method · 0.45

Tested by

no test coverage detected