(self, epoch)
| 50 | self.writer.reset('val') |
| 51 | |
def train_epoch(self, epoch):
    """Run one training pass over ``self.train_data_loader``.

    For each batch this clears the optimizer's gradients, builds the model
    input and target with ``prepare_input``, runs the forward pass, evaluates
    ``self.criterion``, back-propagates and takes an optimizer step.
    Per-batch scores are handed to ``self.writer.update_scores``. Every
    ``self.terminal_show_freq`` batches, ``self.writer.display_terminal`` is
    called with a fractional epoch value. Once the loader is exhausted, it is
    called again with ``summary=True``.

    Args:
        epoch: Current epoch number. It is used to derive the global step
            (``epoch * self.len_epoch + batch index``) passed to the writer,
            and the fractional epoch shown on the terminal.
    """
    self.model.train()

    for step, batch in enumerate(self.train_data_loader):
        self.optimizer.zero_grad()

        inputs, target = prepare_input(input_tuple=batch, args=self.args)
        # NOTE(review): nothing visible here reads the input's gradient. This
        # may exist for a model that needs a grad-requiring input (e.g.
        # reentrant activation checkpointing) -- confirm before removing.
        inputs.requires_grad = True
        prediction = self.model(inputs)
        loss, per_channel = self.criterion(prediction, target)
        loss.backward()
        self.optimizer.step()

        global_step = epoch * self.len_epoch + step
        self.writer.update_scores(step, loss.item(), per_channel, 'train', global_step)

        if (step + 1) % self.terminal_show_freq != 0:
            continue

        # NOTE(review): the trailing "- 1" suggests epochs are counted from 1
        # by the caller -- confirm against the outer training loop.
        fractional_epoch = epoch + step / self.len_epoch - 1
        self.writer.display_terminal(fractional_epoch, epoch, 'train')

    self.writer.display_terminal(self.len_epoch, epoch, mode='train', summary=True)
| 74 | |
| 75 | def validate_epoch(self, epoch): |
| 76 | self.model.eval() |
no test coverage detected