MCPcopy
hub / github.com/modelscope/FunASR / save_checkpoint

Method save_checkpoint

funasr/train_utils/trainer.py:148–263  ·  view source on GitHub ↗

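Saves a checkpoint containing the model's state, the optimizer's state, and the scheduler's state at the end of the given epoch. This method is intended to be called at the end of each epoch to save the training progress. If both `step` and `step_in_epoch` are given, an intra-epoch checkpoint named `model.pt.ep{epoch}.{step_in_epoch}` is written instead of `model.pt.ep{epoch}`. Args: epoch (int): The epoch number at which the checkpoint is being saved.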

(
        self,
        epoch,
        step=None,
        model=None,
        optim=None,
        scheduler=None,
        scaler=None,
        step_in_epoch=None,
        **kwargs,
    )

Source from the content-addressed store, hash-verified (excerpt: lines 146–205 shown; the method continues to line 263)

146 )
147
148 def save_checkpoint(
149 self,
150 epoch,
151 step=None,
152 model=None,
153 optim=None,
154 scheduler=None,
155 scaler=None,
156 step_in_epoch=None,
157 **kwargs,
158 ):
159 """
160 Saves a checkpoint containing the model's state, the optimizer's state,
161 and the scheduler's state at the end of the given epoch. This method is
162 intended to be called at the end of each epoch to save the training progress.
163
164 Args:
165 epoch (int): The epoch number at which the checkpoint is being saved.
166 """
167
168 step_in_epoch = None if step is None else step_in_epoch
169 if self.rank == 0:
170 logging.info(f"Save checkpoint: {epoch}, rank: {self.local_rank}\n")
171 # self.step_or_epoch += 1
172 state = {
173 "epoch": epoch,
174 "step": step,
175 "total_step": self.batch_total,
176 "state_dict": model.state_dict(),
177 "optimizer": optim.state_dict(),
178 "scheduler": scheduler.state_dict(),
179 "saved_ckpts": self.saved_ckpts,
180 "val_acc_step_or_epoch": self.val_acc_step_or_epoch,
181 "val_loss_step_or_epoch": self.val_loss_step_or_epoch,
182 "best_step_or_epoch": self.best_step_or_epoch,
183 "avg_keep_nbest_models_type": self.avg_keep_nbest_models_type,
184 "step_in_epoch": step_in_epoch,
185 "data_split_i": kwargs.get("data_split_i", 0),
186 "data_split_num": kwargs.get("data_split_num", 1),
187 "batch_total": self.batch_total,
188 "train_loss_avg": kwargs.get("train_loss_avg", 0),
189 "train_acc_avg": kwargs.get("train_acc_avg", 0),
190 }
191 step = step_in_epoch
192 if hasattr(model, "module"):
193 state["state_dict"] = model.module.state_dict()
194
195 if scaler:
196 state["scaler_state"] = scaler.state_dict()
197
198 # Create output directory if it does not exist
199 os.makedirs(self.output_dir, exist_ok=True)
200 if step is None:
201 ckpt_name = f"model.pt.ep{epoch}"
202 else:
203 ckpt_name = f"model.pt.ep{epoch}.{step}"
204 filename = os.path.join(self.output_dir, ckpt_name)
205 torch.save(state, filename)

Callers 3

train_epoch (Method) · 0.95
main (Function) · 0.95
main (Function) · 0.95

Calls 1

state_dict (Method) · 0.45

Tested by

no test coverage detected