(self)
| 122 | return metrics_log_dict |
| 123 | |
def configure_optimizers(self):
    """Build the optimizer and learning-rate scheduler from the training config.

    Both classes are resolved from an import-path string with
    ``get_obj_from_str``. A bare class name (a string containing no ``.``)
    is shorthand for a class in ``torch.optim`` (for the optimizer) or
    ``torch.optim.lr_scheduler`` (for the scheduler). A dotted path is used
    as given.

    Config read:
        ``self.hparams.cfg.TRAIN.OPTIM.target`` / ``.params`` and
        ``self.hparams.cfg.TRAIN.LR_SCHEDULER.target`` / ``.params``.
        Each ``params`` is unpacked as keyword arguments into the
        corresponding constructor.

    Returns:
        dict: ``{'optimizer': optimizer, 'lr_scheduler': lr_scheduler}``.
    """

    def qualify(target, default_module):
        # Fully dotted paths pass through; bare names get the default module prefix.
        return target if '.' in target else default_module + target

    train_cfg = self.hparams.cfg.TRAIN

    # Optimizer: built over all of this module's parameters.
    optim_cfg = train_cfg.OPTIM
    optimizer_cls = get_obj_from_str(qualify(optim_cfg.target, 'torch.optim.'))
    optimizer = optimizer_cls(params=self.parameters(), **optim_cfg.params)

    # Scheduler: wraps the optimizer built above.
    sched_cfg = train_cfg.LR_SCHEDULER
    scheduler_cls = get_obj_from_str(
        qualify(sched_cfg.target, 'torch.optim.lr_scheduler.'))
    lr_scheduler = scheduler_cls(optimizer=optimizer, **sched_cfg.params)

    return {'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
| 140 | |
def configure_metrics(self):
    """Create the metrics helper and store it on ``self.metrics``.

    ``BaseMetrics`` receives ``self.datamodule`` as ``datamodule``. Every
    entry of ``self.hparams`` is also passed as a keyword argument.
    """
    # Read in the same order as before: datamodule first, then hparams.
    data_module = self.datamodule
    hyper_params = self.hparams
    self.metrics = BaseMetrics(datamodule=data_module, **hyper_params)
nothing calls this directly
no test coverage detected