| 144 | ) |
| 145 | |
def validation_epoch_end(self, outputs: List[Any]):
    """Log every validation metric, then one aggregate over all of them.

    For each metric object in ``self.val_metrics`` this calls ``compute()``,
    logs every entry of the returned mapping under ``val/<key>``, feeds the
    same value into ``self.mean_metrics`` and finally resets the metric.
    After all metrics are processed, the result of
    ``self.mean_metrics.compute()`` is logged as ``val/mean_metric`` and the
    aggregator is reset.

    Args:
        outputs: Accepted for the hook signature; not read by this method.
    """
    # NOTE(review): loop nesting was reconstructed from a flattened listing;
    # confirm that update() belongs to the inner loop and reset() to the outer.
    epoch_log_opts = {"on_step": False, "on_epoch": True, "prog_bar": True}
    per_key_log_opts = {**epoch_log_opts, "add_dataloader_idx": False}

    for val_metric in self.val_metrics.values():
        computed = val_metric.compute()
        for name, score in computed.items():
            self.log(f"val/{name}", score, **per_key_log_opts)
            # Every individual value also contributes to the aggregate.
            self.mean_metrics.update(score)
        # Clear accumulated state once this metric's values are logged.
        val_metric.reset()

    aggregate = self.mean_metrics.compute()
    self.log("val/mean_metric", aggregate, **epoch_log_opts)
    self.mean_metrics.reset()
| 168 | |
| 169 | def test_step(self, batch: Any, batch_idx: int, dataloader_idx=0): |
| 170 | task_name = self.task_list[dataloader_idx] |