(self, filename: str)
| 189 | self.metrics.save(os.path.join(self.config.log_dir, 'metrics.json')) |
| 190 | |
| 191 | def save_checkpoint(self, filename: str): |
| 192 | checkpoint_path = os.path.join(self.config.model_dir, filename) |
| 193 | |
| 194 | checkpoint = { |
| 195 | 'epoch': self.current_epoch, |
| 196 | 'global_step': self.global_step, |
| 197 | 'model_state_dict': self.model.state_dict(), |
| 198 | 'optimizer_state_dict': self.optimizer.state_dict(), |
| 199 | 'best_val_loss': self.best_val_loss, |
| 200 | 'config': self.config, |
| 201 | } |
| 202 | |
| 203 | if self.scheduler is not None: |
| 204 | checkpoint['scheduler_state_dict'] = self.scheduler.state_dict() |
| 205 | |
| 206 | if self.scaler is not None: |
| 207 | checkpoint['scaler_state_dict'] = self.scaler.state_dict() |
| 208 | |
| 209 | torch.save(checkpoint, checkpoint_path) |
| 210 | self.logger.info(f"Checkpoint saved: {checkpoint_path}") |
| 211 | |
| 212 | def load_checkpoint(self, checkpoint_path: str): |
| 213 | self.logger.info(f"Loading checkpoint: {checkpoint_path}") |
no test coverage detected