| 214 | self.train_optimizer.step() |
| 215 | |
| 216 | def test_epoch(self, data_loader): |
| 217 | self.dnn_model.eval() |
| 218 | |
| 219 | scores = [] |
| 220 | losses = [] |
| 221 | |
| 222 | for data, weight in data_loader: |
| 223 | feature, label = self._get_fl(data) |
| 224 | |
| 225 | with torch.no_grad(): |
| 226 | pred = self.dnn_model(feature.float()) |
| 227 | loss = self.loss_fn(pred, label, weight.to(self.device)) |
| 228 | losses.append(loss.item()) |
| 229 | |
| 230 | score = self.metric_fn(pred, label) |
| 231 | scores.append(score.item()) |
| 232 | |
| 233 | return np.mean(losses), np.mean(scores) |
| 234 | |
| 235 | def fit( |
| 236 | self, |