        self.val_acc_best.reset()

    def model_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        :param batch: A batch of data (a tuple) containing the input tensor of images and target labels.

        :return: A tuple containing (in order):
            - A tensor of losses.
            - A tensor of predictions.
            - A tensor of target labels.
        """
        x, y = batch
        # Raw, unnormalized class scores (logits) from the network.
        logits = self.forward(x)
        loss = self.criterion(logits, y)
        # Predicted class for each sample = index of its highest logit.
        preds = torch.argmax(logits, dim=1)
        return loss, preds, y

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
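To make the data flow of `model_step` concrete, here is a minimal standalone sketch. It assumes `self.criterion` is `torch.nn.CrossEntropyLoss` (the excerpt does not show which criterion is configured), and it uses a hypothetical free function `model_step(net, criterion, batch)` plus a hypothetical one-layer linear network in place of the module's real `self.forward`. It is an illustration of the same three steps (forward pass, loss on logits, argmax for predictions), not the project's own code.

```python
from typing import Tuple

import torch
from torch import nn


def model_step(
    net: nn.Module,
    criterion: nn.Module,
    batch: Tuple[torch.Tensor, torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical free-function mirror of the LightningModule method."""
    x, y = batch
    logits = net(x)  # shape: (batch_size, num_classes)
    loss = criterion(logits, y)  # a 0-dim tensor when reduction="mean" (the default)
    preds = torch.argmax(logits, dim=1)  # shape: (batch_size,), dtype int64
    return loss, preds, y


if __name__ == "__main__":
    torch.manual_seed(0)
    # Hypothetical stand-in for an MNIST classifier: flatten 1x28x28 images, map to 10 classes.
    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    criterion = nn.CrossEntropyLoss()  # assumption: not shown in the excerpt
    images = torch.randn(4, 1, 28, 28)
    labels = torch.tensor([3, 1, 4, 1])
    loss, preds, targets = model_step(net, criterion, (images, labels))
    print(loss.shape, preds.shape)  # torch.Size([]) torch.Size([4])
```

The loss is computed from `logits` rather than from `preds` because `argmax` is not differentiable; the predictions are only used for metrics such as accuracy, while gradients flow through the loss.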