Repository: github.com/geekcomputers/Python

Method save_checkpoint

Defined in ML/src/python/neuralforge/trainer.py, lines 191–210.
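Signature: save_checkpoint(self, filename: str)

Saves the current training state with torch.save to a file named filename inside config.model_dir. The checkpoint always holds the epoch, global step, model and optimizer state dicts, best validation loss, and the config object. It also holds the scheduler and scaler state dicts when those are set.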

Source

```python
import os

import torch


# Method of the trainer class in trainer.py (lines 191–210); class body omitted.
def save_checkpoint(self, filename: str):
    # Checkpoints are written into the configured model directory.
    checkpoint_path = os.path.join(self.config.model_dir, filename)

    # Core training state, always saved.
    checkpoint = {
        'epoch': self.current_epoch,
        'global_step': self.global_step,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'best_val_loss': self.best_val_loss,
        'config': self.config,  # pickled as a Python object
    }

    # Optional components are stored only when the trainer has them.
    if self.scheduler is not None:
        checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()

    if self.scaler is not None:
        checkpoint['scaler_state_dict'] = self.scaler.state_dict()

    torch.save(checkpoint, checkpoint_path)
    self.logger.info(f"Checkpoint saved: {checkpoint_path}")
```
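Resuming from such a file reverses each step: load the dict and hand every `*_state_dict` entry back to the object that produced it. The sketch below is a hypothetical `restore_from_checkpoint` helper, not the repository's own loading code. The tiny `nn.Linear` model, SGD optimizer, `StepLR` scheduler, and `SimpleNamespace` config are stand-ins chosen for illustration. Because the checkpoint pickles the config object, PyTorch 2.6 and later (where `torch.load` defaults to `weights_only=True`) will reject it unless `weights_only=False` is passed, so this should only be done for trusted files.

```python
import os
import tempfile
from types import SimpleNamespace

import torch
from torch import nn


def restore_from_checkpoint(path, model, optimizer, scheduler=None, scaler=None):
    """Hypothetical inverse of save_checkpoint: restore state in place.

    Returns (epoch, global_step, best_val_loss) so a training loop can resume.
    """
    # weights_only=False because the checkpoint also pickles the config object;
    # only do this for checkpoints from a trusted source.
    checkpoint = torch.load(path, map_location="cpu", weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    # These keys exist only if the object was present when the file was saved.
    if scheduler is not None and "scheduler_state_dict" in checkpoint:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    if scaler is not None and "scaler_state_dict" in checkpoint:
        scaler.load_state_dict(checkpoint["scaler_state_dict"])
    return checkpoint["epoch"], checkpoint["global_step"], checkpoint["best_val_loss"]


# Write a file with the same keys save_checkpoint produces (no scaler here).
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
config = SimpleNamespace(model_dir=tempfile.mkdtemp())  # stand-in config object
path = os.path.join(config.model_dir, "demo.pt")
torch.save(
    {
        "epoch": 3,
        "global_step": 120,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "best_val_loss": 0.42,
        "config": config,
        "scheduler_state_dict": scheduler.state_dict(),
    },
    path,
)

# Restore into freshly constructed objects, as a resumed run would.
new_model = nn.Linear(4, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.1)
new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=10)
epoch, global_step, best_val_loss = restore_from_checkpoint(
    path, new_model, new_optimizer, new_scheduler
)
```

The optional keys are checked with `in` rather than indexed directly, mirroring how the save side only writes them when the scheduler or scaler is not None.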

Callers (1)

train · method · 0.95

Calls (3)

state_dict · method · 0.80 (on the model, optimizer, scheduler, and scaler)
info · method · 0.80 (self.logger.info)
save · method · 0.45 (torch.save)

Tested by

No test coverage detected.
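A first test could check what actually lands on disk. The pytest-style sketch below assumes nothing about the real import path or class name: it copies the method verbatim as a module-level function and calls it on a hypothetical `SimpleNamespace` stub that exposes only the attributes the method reads (`config.model_dir`, `current_epoch`, `global_step`, `model`, `optimizer`, `best_val_loss`, `scheduler`, `scaler`, `logger`). The stub's model, optimizer, and values are illustrative choices.

```python
import logging
import os
import tempfile
from types import SimpleNamespace

import torch
from torch import nn


def save_checkpoint(self, filename: str):
    # Verbatim copy of the method above so this sketch runs on its own;
    # in the repository you would call it on a real trainer instance instead.
    checkpoint_path = os.path.join(self.config.model_dir, filename)
    checkpoint = {
        'epoch': self.current_epoch,
        'global_step': self.global_step,
        'model_state_dict': self.model.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'best_val_loss': self.best_val_loss,
        'config': self.config,
    }
    if self.scheduler is not None:
        checkpoint['scheduler_state_dict'] = self.scheduler.state_dict()
    if self.scaler is not None:
        checkpoint['scaler_state_dict'] = self.scaler.state_dict()
    torch.save(checkpoint, checkpoint_path)
    self.logger.info(f"Checkpoint saved: {checkpoint_path}")


def make_stub_trainer(model_dir, with_scheduler=False):
    """Hypothetical stub exposing only the attributes save_checkpoint reads."""
    model = nn.Linear(3, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scheduler = (
        torch.optim.lr_scheduler.StepLR(optimizer, step_size=5) if with_scheduler else None
    )
    return SimpleNamespace(
        config=SimpleNamespace(model_dir=model_dir),
        current_epoch=2,
        global_step=50,
        model=model,
        optimizer=optimizer,
        best_val_loss=1.5,
        scheduler=scheduler,
        scaler=None,
        logger=logging.getLogger("save_checkpoint_test"),
    )


def test_save_checkpoint_writes_expected_keys():
    with tempfile.TemporaryDirectory() as model_dir:
        trainer = make_stub_trainer(model_dir, with_scheduler=True)
        save_checkpoint(trainer, "ckpt.pt")
        # weights_only=False: the checkpoint pickles the config object.
        ckpt = torch.load(os.path.join(model_dir, "ckpt.pt"), weights_only=False)
    assert ckpt["epoch"] == 2 and ckpt["global_step"] == 50
    assert ckpt["best_val_loss"] == 1.5
    assert "scheduler_state_dict" in ckpt
    assert torch.equal(ckpt["model_state_dict"]["weight"], trainer.model.weight.detach())


def test_optional_entries_omitted_when_none():
    with tempfile.TemporaryDirectory() as model_dir:
        save_checkpoint(make_stub_trainer(model_dir), "plain.pt")
        ckpt = torch.load(os.path.join(model_dir, "plain.pt"), weights_only=False)
    assert "scheduler_state_dict" not in ckpt
    assert "scaler_state_dict" not in ckpt
```

Under pytest both `test_*` functions are collected automatically. Testing through a stub keeps the check independent of how the real trainer builds its model, at the cost of duplicating the method body until the test can import it directly.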