| 150 | return items |
| 151 | |
class progressLogger(Callback):
    """Lightning callback that reports training progress through ``logger``.

    Messages emitted:
      * "Training started" / "Training done" at the start/end of training;
      * "Sanity checking ok." after the sanity-check validation pass;
      * one "Epoch N" line at the end of every training epoch, extended with
        the monitored metrics on every ``log_every_n_steps``-th epoch.

    Args:
        logger: Object exposing an ``info(str)`` method that receives all
            messages (presumably a ``logging.Logger`` -- confirm with caller).
        metric_monitor: Mapping from the display name shown in the epoch line
            to the key looked up in ``trainer.callback_metrics``.
        precision: Digits after the decimal point when a metric is rendered
            in scientific notation (``"{:.<precision>e}"``).
        log_every_n_steps: Metrics are appended only when
            ``trainer.current_epoch`` is a multiple of this value. Despite the
            name, it counts epochs, not optimizer steps.

    Raises:
        ValueError: If ``log_every_n_steps`` is smaller than 1 (0 would
            otherwise crash with ZeroDivisionError at the first epoch end).
    """

    def __init__(self,
                 logger,
                 metric_monitor: dict,
                 precision: int = 3,
                 log_every_n_steps: int = 1):
        if log_every_n_steps < 1:
            raise ValueError(
                f"log_every_n_steps must be >= 1, got {log_every_n_steps}")
        self.logger = logger
        # Display name -> key in trainer.callback_metrics.
        self.metric_monitor = metric_monitor
        self.precision = precision
        self.log_every_n_steps = log_every_n_steps

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule,
                       **kwargs) -> None:
        """Announce the beginning of training."""
        self.logger.info("Training started")

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule,
                     **kwargs) -> None:
        """Announce the end of training."""
        self.logger.info("Training done")

    def on_validation_epoch_end(self, trainer: Trainer,
                                pl_module: LightningModule, **kwargs) -> None:
        """Report that the sanity-check validation pass completed."""
        if trainer.sanity_checking:
            self.logger.info("Sanity checking ok.")

    def on_train_epoch_end(self,
                           trainer: Trainer,
                           pl_module: LightningModule,
                           padding=False,
                           **kwargs) -> None:
        """Log ``"Epoch N"`` and, on reporting epochs, the monitored metrics.

        Args:
            trainer: The running trainer; ``current_epoch`` and
                ``callback_metrics`` are read from it.
            pl_module: Unused.
            padding: If true, right-align the ``"Epoch N"`` prefix to the
                width of ``"Epoch xxxx"`` so successive lines line up.
        """
        metric_format = f"{{:.{self.precision}e}}"
        line = f"Epoch {trainer.current_epoch}"
        if padding:
            # Right-align (i.e. pad on the left) to the width of "Epoch xxxx".
            line = f"{line:>{len('Epoch xxxx')}}"

        if trainer.current_epoch % self.log_every_n_steps == 0:
            losses_dict = trainer.callback_metrics
            # Monitored metrics missing this epoch are silently skipped.
            # Values are expected to expose .item() (e.g. scalar tensors).
            metrics_str = [
                f"{name} {metric_format.format(losses_dict[key].item())}"
                for name, key in self.metric_monitor.items()
                if key in losses_dict
            ]
            # Only add the separator when there is something to show, rather
            # than emitting a dangling "Epoch N: ".
            if metrics_str:
                line = line + ": " + " ".join(metrics_str)

        # Emitted every epoch; non-reporting epochs get just the prefix.
        self.logger.info(line)
no outgoing calls
no test coverage detected