| 103 | return {'loss': avg_loss, 'accuracy': accuracy} |
| 104 | |
def validate(self) -> Dict[str, float]:
    """Evaluate ``self.model`` on ``self.val_loader`` without gradient tracking.

    Every batch is moved to ``self.device`` and run through the model in eval
    mode. When ``self.scaler`` is set, the forward pass and loss run under
    ``amp.autocast('cuda')``. Loss and top-1 accuracy are accumulated across
    batches.

    Returns:
        ``{'loss': average of the per-batch criterion values,
        'accuracy': percentage of correct top-1 predictions}``.
        Returns an empty dict when there is no validation loader or when the
        loader yields no samples, since no metric can be computed then.
    """
    if self.val_loader is None:
        return {}

    self.model.eval()
    loss_sum = 0.0
    num_batches = 0  # counted here so loaders without __len__ also work
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in tqdm(self.val_loader, desc="Validation"):
            inputs = inputs.to(self.device, non_blocking=True)
            targets = targets.to(self.device, non_blocking=True)

            if self.scaler is not None:
                # Mixed-precision path, enabled by the presence of a scaler.
                with amp.autocast('cuda'):
                    outputs = self.model(inputs)
                    loss = self.criterion(outputs, targets)
            else:
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

            loss_sum += loss.item()
            num_batches += 1
            # Top-1 prediction over dim 1 (class dimension, presumably).
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # An empty loader would otherwise divide by zero. Report "no metrics"
    # the same way as the missing-loader case above.
    if num_batches == 0 or total == 0:
        return {}

    return {
        'loss': loss_sum / num_batches,
        'accuracy': 100.0 * correct / total,
    }
| 136 | |
| 137 | def train(self): |
| 138 | self.logger.info("Starting training...") |