(self)
| 101 | return loss_dict |
| 102 | |
def metrics_log_dict(self):
    """Compute every selected metric group and flatten the results for logging.

    Each selected group name is resolved as an attribute of ``self.metrics``.
    Its ``compute`` method is called with
    ``sanity_flag=self.trainer.sanity_checking``. Every entry of the returned
    mapping is stored under the key ``"Metrics/<entry name>"``.

    Returns:
        dict: Flat mapping from ``"Metrics/<entry name>"`` to a logged value.
        Values exposing ``.item()`` are unwrapped through it. Other values
        (e.g. plain Python numbers) are stored unchanged. If two groups return
        the same entry name, the group computed later overwrites the earlier
        value.
    """
    # For TM2TMetrics MM: on a multimodality datamodule, only the MMMetrics
    # group is computed. It replaces the configured list entirely.
    if self.trainer.datamodule.is_mm and "TM2TMetrics" in self.hparams.metrics_dict:
        metric_groups = ["MMMetrics"]
    else:
        metric_groups = self.hparams.metrics_dict

    # Compute all metrics
    metrics_log_dict = {}
    for group in metric_groups:
        results = getattr(self.metrics, group).compute(
            sanity_flag=self.trainer.sanity_checking)
        # `entry` is deliberately distinct from `group`. The original code
        # reused one name for both, which made the key source ambiguous.
        metrics_log_dict.update({
            f"Metrics/{entry}": value.item() if hasattr(value, "item") else value
            for entry, value in results.items()
        })

    return metrics_log_dict
| 123 | |
| 124 | def configure_optimizers(self): |
| 125 | # Optimizer |
no test coverage detected