| 525 | |
| 526 | |
class CUDACallback(Callback):
    """Log wall-clock time and peak CUDA memory for every training epoch.

    Adapted from https://github.com/SeanNaren/minGPT/blob/master/mingpt/callback.py.

    Both statistics are measured on each process's
    ``trainer.strategy.root_device``. They are averaged across processes with
    ``trainer.strategy.reduce`` and logged via ``rank_zero_info``.
    """

    def on_train_start(self, trainer, pl_module):
        """Log that training has started."""
        rank_zero_info("Training is starting")

    # Called once, when the whole training run ends -- not once per epoch.
    def on_train_end(self, trainer, pl_module):
        """Log that training has finished."""
        rank_zero_info("Training is ending")

    def on_train_epoch_start(self, trainer, pl_module):
        """Reset the peak-memory counter and start the epoch timer."""
        device_index = trainer.strategy.root_device.index
        # Reset the peak counter so the epoch-end reading covers this epoch only.
        torch.cuda.reset_peak_memory_stats(device_index)
        # Drain queued kernels so earlier work is not billed to this epoch.
        torch.cuda.synchronize(device_index)
        # perf_counter is monotonic, so durations are immune to clock changes.
        self.start_time = time.perf_counter()

    def on_train_epoch_end(self, trainer, pl_module):
        """Log the epoch's duration (seconds) and peak allocated memory (MiB).

        Values are averaged across processes via the strategy's ``reduce``.
        If that raises ``AttributeError``, the unreduced values of the logging
        rank are reported instead, so the epoch is never silently skipped.
        """
        device = trainer.strategy.root_device
        # Wait for this epoch's kernels so both time and memory are complete.
        torch.cuda.synchronize(device.index)
        max_memory = torch.cuda.max_memory_allocated(device.index) / 2**20
        epoch_time = time.perf_counter() - self.start_time

        try:
            mean_memory = self._mean_across_processes(trainer, max_memory, device)
            mean_time = self._mean_across_processes(trainer, epoch_time, device)
        except AttributeError:
            # Best-effort: the strategy could not reduce, so report this
            # rank's own numbers rather than dropping the output entirely.
            rank_zero_info(f"Epoch time (rank 0): {epoch_time:.2f} seconds")
            rank_zero_info(f"Peak memory (rank 0) {max_memory:.2f}MiB")
            return

        rank_zero_info(f"Average Epoch time: {mean_time:.2f} seconds")
        rank_zero_info(f"Average Peak memory {mean_memory:.2f}MiB")

    @staticmethod
    def _mean_across_processes(trainer, value, device):
        """Return ``value`` averaged over all processes, as a Python float.

        In Lightning's distributed strategies, ``reduce`` only communicates
        tensors; plain Python numbers are handed back unchanged. The value is
        therefore wrapped in a scalar tensor on this process's device. Each
        statistic is reduced on its own because some strategies (e.g.
        DataParallel) reduce by taking the mean of the tensor's elements.
        """
        tensor = torch.tensor(value, dtype=torch.float64, device=device)
        return trainer.strategy.reduce(tensor).item()
| 556 | |
| 557 | |
| 558 | if __name__ == "__main__": |
no outgoing calls
no test coverage detected
searching dependent graphs…