| 48 | self.logger.info(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}") |
| 49 | |
def train_epoch(self) -> Dict[str, float]:
    """Train the model for one pass over ``self.train_loader``.

    For every batch the inputs and targets are moved to ``self.device``,
    a forward/backward pass is run, gradients are optionally clipped to
    ``self.config.grad_clip`` (when it is > 0), and the optimizer is
    stepped. When ``self.scaler`` is not ``None`` the forward pass runs
    under ``amp.autocast('cuda')`` and the scaler drives the backward pass
    and optimizer step. ``self.global_step`` is incremented once per batch.

    Accuracy is computed by comparing the argmax over dim 1 of the model
    output with ``targets``.
    NOTE(review): this presumes outputs are (batch, num_classes) scores and
    targets are class indices -- confirm against the dataset/model.

    Returns:
        A dict with ``'loss'`` (mean of the per-batch loss values over the
        batches actually iterated) and ``'accuracy'`` (percentage of
        correct predictions over all samples seen). Both are ``0.0`` when
        the loader yields no batches.
    """
    self.model.train()
    epoch_loss = 0.0
    correct = 0
    total = 0
    # Count batches ourselves rather than relying on len(train_loader):
    # the loader may not define __len__, and an empty loader must not
    # trigger a division by zero below.
    num_batches = 0
    grad_clip = self.config.grad_clip

    pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1}/{self.config.epochs}")

    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs = inputs.to(self.device, non_blocking=True)
        targets = targets.to(self.device, non_blocking=True)

        self.optimizer.zero_grad(set_to_none=True)

        if self.scaler is not None:
            with amp.autocast('cuda'):
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

            self.scaler.scale(loss).backward()

            if grad_clip > 0:
                # Gradients must be unscaled before clipping so the
                # threshold applies to their true magnitudes.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)

            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()

            if grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), grad_clip)

            self.optimizer.step()

        # Read the scalar loss once per step; each .item() call forces a
        # host/device synchronisation.
        batch_loss = loss.item()
        epoch_loss += batch_loss
        num_batches += 1

        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        self.global_step += 1

        if batch_idx % 10 == 0:
            running_acc = 100. * correct / total if total else 0.0
            pbar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'acc': f'{running_acc:.2f}%'
            })

    if num_batches == 0:
        # Nothing was trained on; report zero metrics instead of crashing.
        self.logger.warning("train_epoch: train_loader yielded no batches; returning zero metrics")
        return {'loss': 0.0, 'accuracy': 0.0}

    avg_loss = epoch_loss / num_batches
    accuracy = 100. * correct / total if total else 0.0

    return {'loss': avg_loss, 'accuracy': accuracy}
| 104 | |
| 105 | def validate(self) -> Dict[str, float]: |
| 106 | if self.val_loader is None: |