MCPcopy
hub / github.com/hpcaitech/ColossalAI / SetupCallback

Class SetupCallback

examples/images/diffusion/main.py:318–379  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

316
317
class SetupCallback(Callback):
    """Callback that prepares the run's directories and persists its configs.

    On global rank 0 it creates the log, checkpoint and config directories at
    fit start, prints both configs and writes them to ``cfgdir`` as YAML. On
    every other rank it moves an already existing ``logdir`` into a
    ``child_runs`` sibling folder unless the run is being resumed. On a
    keyboard interrupt, rank 0 writes ``last.ckpt`` into ``ckptdir``.
    """

    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        """Store the run settings used later by the hooks.

        Args:
            resume: Truthy when the run is resumed; disables the log
                directory relocation on non-zero ranks.
            now: Value interpolated as the prefix of the saved YAML file names.
            logdir: Path of the run's log directory.
            ckptdir: Path of the checkpoint directory.
            cfgdir: Path of the directory receiving the YAML config dumps.
            config: Project config, printed and saved via OmegaConf.
            lightning_config: Lightning config, printed and saved via
                OmegaConf; also probed for a ``callbacks`` entry.
        """
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_keyboard_interrupt(self, trainer, pl_module):
        """Write ``last.ckpt`` from global rank 0 when training is interrupted."""
        if trainer.global_rank != 0:
            return

        print("Summoning checkpoint.")
        trainer.save_checkpoint(os.path.join(self.ckptdir, "last.ckpt"))

    # Previously registered as on_pretrain_routine_start(self, trainer, pl_module).
    def on_fit_start(self, trainer, pl_module):
        """Set up directories and dump configs (rank 0) or relocate logdir (other ranks)."""
        if trainer.global_rank != 0:
            # Non-zero ranks: nothing to do when resuming or when no log
            # directory exists yet.
            if self.resume or not os.path.exists(self.logdir):
                return

            # The original note attributes the pre-existing directory to the
            # ModelCheckpoint callback; it is moved to <parent>/child_runs/<name>.
            parent, run_name = os.path.split(self.logdir)
            target = os.path.join(parent, "child_runs", run_name)
            os.makedirs(os.path.dirname(target), exist_ok=True)
            try:
                os.rename(self.logdir, target)
            except FileNotFoundError:
                # Best effort: the directory vanished before the rename ran.
                pass
            return

        # Rank 0: create the log, checkpoint and config directories.
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        # Extra checkpoint folder, only when the trainstep-metrics callback is configured.
        lightning_cfg = self.lightning_config
        if "callbacks" in lightning_cfg and "metrics_over_trainsteps_checkpoint" in lightning_cfg["callbacks"]:
            os.makedirs(os.path.join(self.ckptdir, "trainstep_checkpoints"), exist_ok=True)

        # Print the project config and save it as <now>-project.yaml.
        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        project_yaml = os.path.join(self.cfgdir, f"{self.now}-project.yaml")
        OmegaConf.save(self.config, project_yaml)

        # Print the lightning config and save it, nested under a "lightning"
        # key, as <now>-lightning.yaml.
        print("Lightning config")
        print(OmegaConf.to_yaml(self.lightning_config))
        OmegaConf.save(
            OmegaConf.create({"lightning": self.lightning_config}),
            os.path.join(self.cfgdir, f"{self.now}-lightning.yaml"),
        )
374
375 # def on_fit_end(self, trainer, pl_module):

Callers 1

main.pyFile · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…