| 126 | update_ema(self.ema, self.model) |
| 127 | |
def on_save_checkpoint(self, checkpoint):
    """Write a standalone weights file next to each Lightning checkpoint.

    After letting parent classes handle Lightning's ``checkpoint`` dict, this
    hook saves a separate file ``<checkpoint dir>/epoch{epoch}-step{step}.ckpt``
    containing only the ``state_dict()`` of ``self.model`` and ``self.ema``.

    Args:
        checkpoint: The checkpoint dict Lightning is about to persist. It is
            forwarded to ``super()`` and is not otherwise modified here.
    """
    import os  # local import: the file's top-level imports are not visible here

    super().on_save_checkpoint(checkpoint)

    checkpoint_dir = self.trainer.checkpoint_callback.dirpath
    epoch = self.trainer.current_epoch
    step = self.trainer.global_step

    # Distinct name so the Lightning ``checkpoint`` dict is not shadowed.
    # State dicts are gathered on every process in case ``state_dict()`` needs
    # all ranks to participate (e.g. under sharded strategies).
    weights = {
        "model": self.model.state_dict(),
        "ema": self.ema.state_dict(),
    }

    # Save hooks run on every process in distributed training; only rank 0
    # writes, otherwise all ranks race to write the same file.
    if not self.trainer.is_global_zero:
        return

    # On the first save this hook can run before Lightning has created the
    # checkpoint directory. NOTE(review): assumes a local filesystem path.
    os.makedirs(checkpoint_dir, exist_ok=True)
    weights_path = os.path.join(checkpoint_dir, f"epoch{epoch}-step{step}.ckpt")
    torch.save(weights, weights_path)
| 138 | |
| 139 | def configure_optimizers(self): |
| 140 | self.lr_scheduler = get_scheduler( |