| 116 | update_ema(self.ema, self.model) |
| 117 | |
def on_save_checkpoint(self, checkpoint):
    """Run the parent checkpoint hook, then export bare model/EMA weights.

    In addition to whatever the parent hook does, this writes a separate
    file ``epoch{epoch}-step{step}.ckpt`` into the ModelCheckpoint
    callback's ``dirpath``. That file holds only
    ``{"model": self.model.state_dict(), "ema": self.ema.state_dict()}``.
    The Lightning ``checkpoint`` dict is not modified by this method.

    If no checkpoint callback or directory is available, the extra export
    is skipped with a warning instead of failing the whole checkpoint save.

    Args:
        checkpoint: The checkpoint dict handed to the hook; passed through
            unchanged to the parent implementation.
    """
    import os
    import warnings

    super().on_save_checkpoint(checkpoint)

    callback = self.trainer.checkpoint_callback
    checkpoint_dir = getattr(callback, "dirpath", None)
    if checkpoint_dir is None:
        # Either no ModelCheckpoint callback is configured or its directory
        # is unresolved: there is nowhere to put the extra export.
        warnings.warn(
            "No checkpoint_callback.dirpath available; skipping the "
            "model/EMA weights export.",
            stacklevel=2,
        )
        return

    # Build the state dicts on every rank, as before. NOTE(review): this keeps
    # any collective state_dict() implementation (e.g. sharded strategies)
    # in step across processes — confirm for the strategy in use.
    weights = {
        "model": self.model.state_dict(),
        "ema": self.ema.state_dict(),
    }

    # Under multi-process training this hook can run on every rank; only one
    # process should write, or concurrent torch.save calls race on one path.
    if not self.trainer.is_global_zero:
        return

    # The directory may not exist yet on the first save.
    os.makedirs(checkpoint_dir, exist_ok=True)
    epoch = self.trainer.current_epoch
    step = self.trainer.global_step
    torch.save(weights, os.path.join(checkpoint_dir, f"epoch{epoch}-step{step}.ckpt"))
| 128 | |
| 129 | def configure_optimizers(self): |
| 130 | self.lr_scheduler = get_scheduler( |