| 316 | |
| 317 | |
class SetupCallback(Callback):
    """Prepare run directories, persist configs, and checkpoint on Ctrl+C.

    Args:
        resume: Truthy when the run resumes an earlier one; when falsy, non-zero
            ranks relocate any existing ``logdir`` at fit start.
        now: Prefix used in the saved config file names
            (``<now>-project.yaml`` / ``<now>-lightning.yaml``).
        logdir: Run log directory.
        ckptdir: Checkpoint directory; ``last.ckpt`` is written here on interrupt.
        cfgdir: Directory the config YAML files are written to.
        config: Project config (OmegaConf-compatible), saved as-is.
        lightning_config: Lightning config (OmegaConf-compatible); saved wrapped
            under a ``lightning`` key, and its ``callbacks`` entry is inspected.
    """

    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config
        # Some Lightning versions fire both ``on_keyboard_interrupt`` and
        # ``on_exception`` for a single Ctrl+C; save the checkpoint only once.
        self._interrupt_ckpt_saved = False

    def _save_interrupt_checkpoint(self, trainer):
        """Write ``<ckptdir>/last.ckpt`` from global rank 0, at most once per run."""
        if trainer.global_rank != 0 or self._interrupt_ckpt_saved:
            return
        print("Summoning checkpoint.")
        ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
        # NOTE(review): Lightning's Trainer.save_checkpoint may synchronize across
        # ranks for some strategies; calling it from rank 0 only could block.
        # Confirm for the distributed strategy in use.
        trainer.save_checkpoint(ckpt_path)
        self._interrupt_ckpt_saved = True

    def on_keyboard_interrupt(self, trainer, pl_module):
        """Legacy interrupt hook (older Lightning): save a checkpoint on Ctrl+C."""
        self._save_interrupt_checkpoint(trainer)

    def on_exception(self, trainer, pl_module, exception):
        """Current interrupt path: Lightning reports Ctrl+C via ``on_exception``."""
        if isinstance(exception, KeyboardInterrupt):
            self._save_interrupt_checkpoint(trainer)

    def on_fit_start(self, trainer, pl_module):
        """Set up the run directories when fitting starts.

        This hook replaced the older ``on_pretrain_routine_start``. Global rank 0
        creates the directories and saves both configs. Every other rank, when
        not resuming, moves an existing ``logdir`` into a sibling ``child_runs/``
        folder.
        """
        if trainer.global_rank == 0:
            self._create_dirs_and_save_configs()
        elif not self.resume and os.path.exists(self.logdir):
            # ModelCheckpoint callback created log directory --- move it aside.
            # NOTE(review): this assumes each rank's logdir path differs from rank
            # 0's (e.g. via a per-process ``now``); if they coincide, this would
            # move rank 0's directory. Confirm against how logdir is built.
            self._move_logdir_to_child_runs()

    def _create_dirs_and_save_configs(self):
        """Create log/ckpt/cfg directories and dump both configs (rank 0 only)."""
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        # Pre-create the subfolder used when a metrics-over-trainsteps
        # checkpoint callback is configured.
        callbacks_cfg = self.lightning_config["callbacks"] if "callbacks" in self.lightning_config else None
        if callbacks_cfg is not None and "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
            os.makedirs(os.path.join(self.ckptdir, "trainstep_checkpoints"), exist_ok=True)

        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        OmegaConf.save(self.config, os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

        print("Lightning config")
        print(OmegaConf.to_yaml(self.lightning_config))
        OmegaConf.save(
            OmegaConf.create({"lightning": self.lightning_config}),
            os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
        )

    def _move_logdir_to_child_runs(self):
        """Relocate ``logdir`` to ``<parent>/child_runs/<name>`` (best effort)."""
        parent, name = os.path.split(self.logdir)
        dst = os.path.join(parent, "child_runs", name)
        os.makedirs(os.path.dirname(dst), exist_ok=True)
        try:
            os.rename(self.logdir, dst)
        except FileNotFoundError:
            # logdir vanished between the exists() check and the rename.
            pass
        except OSError as err:
            # e.g. the destination already exists and is not empty. This is a
            # cleanup step, so report it instead of crashing this rank.
            print(f"Could not move {self.logdir} to {dst}: {err}")
| 375 | # def on_fit_end(self, trainer, pl_module): |
no outgoing calls
no test coverage detected
searching dependent graphs…