(self, batch, batch_idx)
| 32 | return self.layer(x) |
| 33 | |
| 34 | def training_step(self, batch, batch_idx): |
| 35 | loss = self(batch).sum() |
| 36 | self.log("train_loss", loss) |
| 37 | return {"loss": loss} |
| 38 | |
| 39 | def validation_step(self, batch, batch_idx): |
| 40 | loss = self(batch).sum() |