Saves a checkpoint containing the model's state, the optimizer's state, and the scheduler's state at the end of the given epoch. This method is intended to be called at the end of each epoch to save the training progress. Args: epoch (int): The epoch number at which the checkpoint is being saved.
(
self,
epoch,
step=None,
model=None,
optim=None,
scheduler=None,
scaler=None,
step_in_epoch=None,
**kwargs,
)
| 146 | ) |
| 147 | |
| 148 | def save_checkpoint( |
| 149 | self, |
| 150 | epoch, |
| 151 | step=None, |
| 152 | model=None, |
| 153 | optim=None, |
| 154 | scheduler=None, |
| 155 | scaler=None, |
| 156 | step_in_epoch=None, |
| 157 | **kwargs, |
| 158 | ): |
| 159 | """ |
| 160 | Saves a checkpoint containing the model's state, the optimizer's state, |
| 161 | and the scheduler's state at the end of the given epoch. This method is |
| 162 | intended to be called at the end of each epoch to save the training progress. |
| 163 | |
| 164 | Args: |
| 165 | epoch (int): The epoch number at which the checkpoint is being saved. |
| 166 | """ |
| 167 | |
| 168 | step_in_epoch = None if step is None else step_in_epoch |
| 169 | if self.rank == 0: |
| 170 | logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n") |
| 171 | # self.step_or_epoch += 1 |
| 172 | state = { |
| 173 | "epoch": epoch, |
| 174 | "step": step, |
| 175 | "total_step": self.batch_total, |
| 176 | "state_dict": model.state_dict(), |
| 177 | "optimizer": optim.state_dict(), |
| 178 | "scheduler": scheduler.state_dict(), |
| 179 | "saved_ckpts": self.saved_ckpts, |
| 180 | "val_acc_step_or_epoch": self.val_acc_step_or_epoch, |
| 181 | "val_loss_step_or_epoch": self.val_loss_step_or_epoch, |
| 182 | "best_step_or_epoch": self.best_step_or_epoch, |
| 183 | "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type, |
| 184 | "step_in_epoch": step_in_epoch, |
| 185 | "data_split_i": kwargs.get("data_split_i", 0), |
| 186 | "data_split_num": kwargs.get("data_split_num", 1), |
| 187 | "batch_total": self.batch_total, |
| 188 | "train_loss_avg": kwargs.get("train_loss_avg", 0), |
| 189 | "train_acc_avg": kwargs.get("train_acc_avg", 0), |
| 190 | } |
| 191 | step = step_in_epoch |
| 192 | if hasattr(model, "module"): |
| 193 | state["state_dict"] = model.module.state_dict() |
| 194 | |
| 195 | if scaler: |
| 196 | state["scaler_state"] = scaler.state_dict() |
| 197 | |
| 198 | # Create output directory if it does not exist |
| 199 | os.makedirs(self.output_dir, exist_ok=True) |
| 200 | if step is None: |
| 201 | ckpt_name = f"model.pt.ep{epoch}" |
| 202 | else: |
| 203 | ckpt_name = f"model.pt.ep{epoch}.{step}" |
| 204 | filename = os.path.join(self.output_dir, ckpt_name) |
| 205 | torch.save(state, filename) |
no test coverage detected