(
self,
model,
metrics,
criterions,
layer_lrs,
optimizer,
scheduler,
)
| 25 | """ |
| 26 | |
def __init__(
    self,
    model,
    metrics,
    criterions,
    layer_lrs,
    optimizer,
    scheduler,
):
    """Build the multi-task module from un-instantiated config objects.

    Every argument is a config that is resolved with ``instantiate``.

    Args:
        model: config for the network; the instantiated object is stored
            as ``self.model``.
        metrics: config resolving to a mapping ``task -> metric(s)``; each
            entry is wrapped in a ``MetricCollection`` whose keys are
            prefixed with ``"<task>/"``.
        criterions: config resolving to a mapping ``task -> loss module``;
            its keys (in mapping order) become ``self.task_list``.
        layer_lrs: config, instantiated and stored as ``self.layer_lrs``.
        optimizer: config, instantiated and stored as ``self.optimizer``.
        scheduler: config, instantiated and stored as ``self.scheduler``.
    """
    super().__init__()

    # Lightning's save_hyperparameters reads the constructor arguments from
    # this frame, so it runs before anything else. The raw configs become
    # ``self.hparams`` and are stored in the checkpoint; logger=False keeps
    # them from being sent to the logger.
    self.save_hyperparameters(logger=False)

    # NOTE: instantiation order is deliberately unchanged (model first) in
    # case building these objects consumes RNG state.
    self.model: LTPModule = instantiate(model)
    self.layer_lrs = instantiate(layer_lrs)
    self.optimizer = instantiate(optimizer)
    self.scheduler = instantiate(scheduler)

    # One loss function per task; the task order follows the mapping order.
    loss_by_task = instantiate(criterions)
    self.task_list = list(loss_by_task.keys())
    self.criterions = ModuleDict(loss_by_task)

    # Train, val and test each get their own metric objects so that the
    # epoch-level reduction of one stage never mixes in another's updates.
    metric_by_task = instantiate(metrics)
    per_task = {}
    for task_name, task_metric in metric_by_task.items():
        per_task[task_name] = MetricCollection(task_metric, prefix=f"{task_name}/")
    # A ModuleDict (not a plain dict) registers the collections as submodules.
    stage_metrics = ModuleDict(per_task)
    self.train_metrics = stage_metrics
    self.val_metrics = deepcopy(stage_metrics)
    self.test_metrics = deepcopy(stage_metrics)

    self.mean_metrics = MeanMetric()
| 65 | |
| 66 | def forward( |
| 67 | self, |
nothing calls this directly
no test coverage detected