MCPcopy
hub / github.com/OpenMotionLab/MotionGPT / progressLogger

Class progressLogger

mGPT/callback.py:152–200  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

150 return items
151
class progressLogger(Callback):
    """Callback that reports training progress through a supplied logger.

    At the end of each training epoch it logs a line starting with
    ``Epoch <n>``; on reporting epochs the line also carries the monitored
    metrics read from ``trainer.callback_metrics``.
    """

    def __init__(self,
                 logger,
                 metric_monitor: dict,
                 precision: int = 3,
                 log_every_n_steps: int = 1):
        """Store the logging configuration.

        Args:
            logger: Object whose ``info`` method receives every message.
            metric_monitor: Maps the label printed in the log line to the
                key looked up in ``trainer.callback_metrics``.
            precision: Number of digits after the decimal point in the
                scientific-notation rendering of each metric.
            log_every_n_steps: Metrics are appended when
                ``trainer.current_epoch`` is a multiple of this value, so
                despite its name it acts as an interval in epochs.
        """
        self.logger = logger
        self.metric_monitor = metric_monitor
        self.precision = precision
        self.log_every_n_steps = log_every_n_steps

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule,
                       **kwargs) -> None:
        """Log that training has started."""
        self.logger.info("Training started")

    def on_train_end(self, trainer: Trainer, pl_module: LightningModule,
                     **kwargs) -> None:
        """Log that training has finished."""
        self.logger.info("Training done")

    def on_validation_epoch_end(self, trainer: Trainer,
                                pl_module: LightningModule, **kwargs) -> None:
        """Log a confirmation when the validation run was a sanity check."""
        if trainer.sanity_checking:
            self.logger.info("Sanity checking ok.")

    def on_train_epoch_end(self,
                           trainer: Trainer,
                           pl_module: LightningModule,
                           padding=False,
                           **kwargs) -> None:
        """Log the epoch number and, on reporting epochs, the metrics.

        Args:
            trainer: Supplies ``current_epoch`` and ``callback_metrics``.
            pl_module: Unused here.
            padding: When true, right-align the ``Epoch <n>`` header in a
                field as wide as ``'Epoch xxxx'``.
        """
        epoch = trainer.current_epoch
        message = f"Epoch {epoch}"
        if padding:
            # rjust right-aligns the header, i.e. it pads on the left.
            message = message.rjust(len('Epoch xxxx'))

        if epoch % self.log_every_n_steps == 0:
            logged = trainer.callback_metrics
            # Monitored keys absent from the logged metrics are skipped.
            parts = [
                f"{label} {logged[key].item():.{self.precision}e}"
                for label, key in self.metric_monitor.items()
                if key in logged
            ]
            message = f"{message}: " + " ".join(parts)

        # Outside reporting epochs only the bare header is emitted.
        self.logger.info(message)

Callers 1

getCheckpointCallbackFunction · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected