| 210 | self.logger.info(f"Checkpoint saved: {checkpoint_path}") |
| 211 | |
def load_checkpoint(self, checkpoint_path: str):
    """Restore training state from a checkpoint file.

    Loads the checkpoint with ``map_location=self.device``, then restores
    the model and optimizer state dicts and the ``current_epoch``,
    ``global_step`` and ``best_val_loss`` counters. Scheduler and
    grad-scaler state are restored only when the trainer has that component
    AND the checkpoint contains its state. If the trainer has one but the
    checkpoint does not, a warning is logged rather than silently resuming
    with freshly initialised scheduler/scaler state.

    Args:
        checkpoint_path: Path to the checkpoint file.

    Raises:
        KeyError: If one of the required keys ('model_state_dict',
            'optimizer_state_dict', 'epoch', 'global_step',
            'best_val_loss') is missing from the checkpoint.
        Any error raised by ``torch.load`` (e.g. a missing file) propagates.
    """
    self.logger.info(f"Loading checkpoint: {checkpoint_path}")
    # map_location remaps saved tensors onto this trainer's device, so a
    # checkpoint written on one device can be resumed on another.
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # from trusted sources; consider weights_only=True if the saved payload
    # (see save_checkpoint) contains only tensors and primitive values.
    checkpoint = torch.load(checkpoint_path, map_location=self.device)

    self.model.load_state_dict(checkpoint['model_state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # NOTE(review): whether 'epoch' is the last completed epoch or the next
    # one to run depends on save_checkpoint -- confirm before resuming.
    self.current_epoch = checkpoint['epoch']
    self.global_step = checkpoint['global_step']
    self.best_val_loss = checkpoint['best_val_loss']

    # Optional components: restore when present, but never skip silently --
    # a reset LR schedule or loss scale changes training dynamics.
    if self.scheduler is not None:
        if 'scheduler_state_dict' in checkpoint:
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            self.logger.warning(
                "Checkpoint has no 'scheduler_state_dict'; "
                "scheduler state was not restored"
            )

    if self.scaler is not None:
        if 'scaler_state_dict' in checkpoint:
            self.scaler.load_state_dict(checkpoint['scaler_state_dict'])
        else:
            self.logger.warning(
                "Checkpoint has no 'scaler_state_dict'; "
                "grad scaler state was not restored"
            )

    self.logger.info(f"Checkpoint loaded from epoch {self.current_epoch}")
| 229 | |
| 230 | def test(self, test_loader: DataLoader) -> Dict[str, float]: |
| 231 | self.logger.info("Starting testing...") |