MCPcopy
hub / github.com/hpcaitech/ColossalAI / CUDACallback

Class CUDACallback

examples/images/diffusion/main.py:527–555  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

525
526
class CUDACallback(Callback):
    """Callback that logs per-epoch wall-clock time and peak CUDA memory.

    Adapted from
    https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py
    """

    def on_train_start(self, trainer, pl_module):
        """Log that training is starting."""
        rank_zero_info("Training is starting")

    def on_train_end(self, trainer, pl_module):
        """Log that training is ending.

        This is the end-of-training hook, not the per-epoch one; the
        per-epoch report is produced by ``on_train_epoch_end``.
        """
        rank_zero_info("Training is ending")

    def on_train_epoch_start(self, trainer, pl_module):
        """Reset the peak-memory counter and start the epoch timer."""
        device_index = trainer.strategy.root_device.index
        # Reset the memory use counter so the peak read at the end of the
        # epoch only covers this epoch.
        torch.cuda.reset_peak_memory_stats(device_index)
        # Wait for pending GPU work before taking the start timestamp.
        torch.cuda.synchronize(device_index)
        self.start_time = time.time()

    def on_train_epoch_end(self, trainer, pl_module):
        """Log the epoch duration and peak CUDA memory for this epoch.

        Both values are passed through ``trainer.strategy.reduce`` before
        being logged. The log labels say "Average", which assumes the
        strategy's reduction averages across processes -- TODO confirm for
        the strategy in use.
        """
        device_index = trainer.strategy.root_device.index
        # Wait for pending GPU work so the measurements cover the full epoch.
        torch.cuda.synchronize(device_index)
        # Dividing by 2**20 matches the "MiB" unit used in the log message.
        max_memory = torch.cuda.max_memory_allocated(device_index) / 2**20
        epoch_time = time.time() - self.start_time

        # Only the reduction is guarded: the original code tolerated an
        # AttributeError here, and the report stays best-effort, but failures
        # in formatting/logging are no longer hidden by the same handler.
        try:
            max_memory = trainer.strategy.reduce(max_memory)
            epoch_time = trainer.strategy.reduce(epoch_time)
        except AttributeError:
            # Say why the report is missing instead of dropping it silently.
            rank_zero_info(
                "Skipping epoch time/memory report: "
                "could not reduce values across processes"
            )
            return

        rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
        rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
556
557
558if __name__ == "__main__":

Callers 1

main.py · File · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…