| 129 | [metric.reset() for metric in self.train_metrics.values()] |
| 130 | |
def validation_step(self, batch: Any, batch_idx: int, dataloader_idx=0):
    """Run one validation batch for the task tied to ``dataloader_idx``.

    The task name is looked up in ``self.task_list`` by ``dataloader_idx``,
    the batch is passed through ``self.step``, the task's entry in
    ``self.val_metrics`` is updated, and the loss is logged under
    ``val/<task_name>/loss``.

    Args:
        batch: The validation batch. It is unpacked with ``**batch`` when
            updating the metric, so it must be a mapping with string keys.
        batch_idx: Index of the batch; not used in this method.
        dataloader_idx: Position in ``self.task_list`` of the task to
            validate.

    Returns:
        None. Results are recorded through the metric update and ``self.log``.
    """
    current_task = self.task_list[dataloader_idx]
    step_loss, step_outputs = self.step(current_task, batch)

    # Feed this batch into the validation metric of the current task.
    self.val_metrics[current_task].update(step_outputs, **batch)

    # The task name is already part of the logged key.
    log_options = dict(
        on_step=False,
        on_epoch=True,
        prog_bar=False,
        add_dataloader_idx=False,
    )
    self.log(f"val/{current_task}/loss", step_loss, **log_options)
| 145 | |
| 146 | def validation_epoch_end(self, outputs: List[Any]): |
| 147 | for metric in self.val_metrics.values(): |