(self)
| 135 | return {'loss': avg_loss, 'accuracy': accuracy} |
| 136 | |
def train(self):
    """Run the full training loop for ``self.config.epochs`` epochs.

    For every epoch this calls ``self.train_epoch()`` and ``self.validate()``,
    advances the LR scheduler (if one is configured), logs a one-line summary,
    records the epoch's numbers through ``self.metrics.update`` and writes
    checkpoints via ``self.save_checkpoint``:

    * ``checkpoint_epoch_{N}.pt`` every ``config.checkpoint_freq`` epochs
      (periodic checkpoints are skipped when the frequency is not positive);
    * ``best_model.pt`` whenever the validation loss beats
      ``self.best_val_loss``;
    * ``final_model.pt`` once all epochs are done.

    Finally the collected metrics are saved to
    ``os.path.join(config.log_dir, 'metrics.json')``.

    ``self.validate()`` may return a dict without a ``'loss'`` key (or an
    empty dict / ``None``). In that case the validation loss and accuracy
    are logged and recorded as 0, ``ReduceLROnPlateau`` is stepped on the
    training loss instead, and no best-model checkpoint is written.
    """
    self.logger.info("Starting training...")
    start_time = time.time()

    checkpoint_freq = self.config.checkpoint_freq

    for epoch in range(self.config.epochs):
        self.current_epoch = epoch
        epoch_start = time.time()

        train_metrics = self.train_epoch()
        # Normalise so the .get() lookups below never hit None.
        val_metrics = self.validate() or {}
        val_loss = val_metrics.get('loss')

        # Read the LR that was in effect for this epoch *before* the
        # scheduler advances it; reading it afterwards would report the
        # next epoch's learning rate.
        current_lr = self.optimizer.param_groups[0]['lr']

        if self.scheduler is not None:
            if isinstance(self.scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # Plateau schedulers need a monitored value; fall back to the
                # training loss when validation produced none.
                self.scheduler.step(val_metrics.get('loss', train_metrics['loss']))
            else:
                self.scheduler.step()

        epoch_time = time.time() - epoch_start

        self.logger.info(
            f"Epoch {epoch + 1}/{self.config.epochs} | "
            f"Train Loss: {train_metrics['loss']:.4f} | "
            f"Train Acc: {train_metrics['accuracy']:.2f}% | "
            f"Val Loss: {val_metrics.get('loss', 0):.4f} | "
            f"Val Acc: {val_metrics.get('accuracy', 0):.2f}% | "
            f"LR: {current_lr:.6f} | "
            f"Time: {epoch_time:.2f}s"
        )

        self.metrics.update({
            'epoch': epoch + 1,
            'train_loss': train_metrics['loss'],
            'train_acc': train_metrics['accuracy'],
            'val_loss': val_metrics.get('loss', 0),
            'val_acc': val_metrics.get('accuracy', 0),
            'lr': current_lr,
            'time': epoch_time
        })

        # A zero/None/negative frequency disables periodic checkpoints
        # instead of raising ZeroDivisionError on the modulo.
        if checkpoint_freq and checkpoint_freq > 0 and (epoch + 1) % checkpoint_freq == 0:
            self.save_checkpoint(f'checkpoint_epoch_{epoch + 1}.pt')

        # Only compare when validation actually reported a loss.
        if val_loss is not None and val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.save_checkpoint('best_model.pt')
            self.logger.info(f"New best model saved with val_loss: {self.best_val_loss:.4f}")

    total_time = time.time() - start_time
    self.logger.info(f"Training completed in {total_time / 3600:.2f} hours")

    self.save_checkpoint('final_model.pt')
    self.metrics.save(os.path.join(self.config.log_dir, 'metrics.json'))
| 190 | |
| 191 | def save_checkpoint(self, filename: str): |
| 192 | checkpoint_path = os.path.join(self.config.model_dir, filename) |
no test coverage detected