Attempt to restore the training state of the last ckpt. Args: lr_scheduler (_LRScheduler): lr_scheduler object. optimizer (Optimizer): optimizer object. lr (float): learning rate. train_state (dict): training states. train_dl (DataLoader): training dataloader object.
(self, lr_scheduler, optimizer, lr, train_state, train_dl)
| 531 | load_model_checkpoint(folder=model_load_path, model=self.model) |
| 532 | |
| 533 | def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl): |
| 534 | """Attempt to restore the training state of the last ckpt. |
| 535 | |
| 536 | Args: |
| 537 | lr_scheduler (_LRScheduler): lr_scheduler object. |
| 538 | optimizer (Optimizer): optimizer object. |
| 539 | lr (float): learning rate. |
| 540 | train_state (dict): training states. |
| 541 | train_dl (DataLoader): training dataloader object. |
| 542 | """ |
| 543 | if self.load_ckpt_folder is not None: |
| 544 | # load optimizer states. |
| 545 | if self.load_optimizer: |
| 546 | load_optimizer_checkpoint(self.load_ckpt_folder, optimizer) |
| 547 | # load lr scheduler states. |
| 548 | load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state) |
| 549 | # load training states. |
| 550 | load_context(self.load_ckpt_folder, train_dl, train_state) |
| 551 | # load dataloader sampler states. |
| 552 | if hasattr(train_state, "batch_sampler") and not isinstance( |
| 553 | train_state.batch_sampler, torch.utils.data.sampler.BatchSampler |
| 554 | ): |
| 555 | load_sampler(self.load_ckpt_folder, train_dl.batch_sampler) |
| 556 | if hasattr(train_state, "data_state_dict"): |
| 557 | train_dl.dataset.load_state_dict( |
| 558 | llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder |
| 559 | ) |
| 560 | self.optimizer = optimizer |
| 561 | self.lr_scheduler = lr_scheduler |
| 562 | |
| 563 | def save_checkpoint( |
| 564 | self, |
no test coverage detected