| 28 | |
| 29 | |
| 30 | class LatteTrainingModule(LightningModule): |
def __init__(self, args, logger: logging.Logger):
    """Assemble the model, its EMA copy, the diffusion helper, the VAE and the optimizer.

    Args:
        args: Config object. This method reads ``args.pretrained`` (optional
            checkpoint path) and ``args.pretrained_model_path`` (VAE location),
            and forwards ``args`` to ``get_models``.
        logger: Logger for progress messages; kept on ``self.logging``.
    """
    super().__init__()
    self.args = args
    self.logging = logger

    # Trainable network, plus a gradient-free deep copy that will hold the EMA weights.
    self.model = get_models(args)
    self.ema = deepcopy(self.model)
    requires_grad(self.ema, False)

    # Optional warm start of self.model from an old checkpoint (EMA weights only).
    if args.pretrained:
        self._load_pretrained_parameters(args)
    n_params = sum(param.numel() for param in self.model.parameters())
    self.logging.info(f"Model Parameters: {n_params:,}")

    self.diffusion = create_diffusion(timestep_respacing="")

    # The VAE is never trained here: freeze it immediately after loading.
    self.vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
    self.vae.requires_grad_(False)

    # Only self.model's parameters are optimized; no LR scheduler is configured.
    self.opt = torch.optim.AdamW(self.model.parameters(), lr=1e-4, weight_decay=0)
    self.lr_scheduler = None

    # decay=0 syncs the EMA copy to the current (possibly warm-started) model weights.
    update_ema(self.ema, self.model, decay=0)
    # Train mode matters: it enables embedding dropout used for classifier-free guidance.
    # NOTE(review): Lightning may switch the whole module (including self.ema) between
    # train/eval during fit, overriding the calls below — confirm if the EMA mode matters.
    self.model.train()
    self.ema.eval()
| 56 | |
def _load_pretrained_parameters(self, args):
    """Copy matching weights from the checkpoint at ``args.pretrained`` into ``self.model``.

    If the loaded checkpoint dict contains an ``"ema"`` entry, that sub-dict is
    used as the state dict. Checkpoint keys unknown to ``self.model`` are logged
    and skipped. Model entries absent from the checkpoint keep their current
    values. Tensor-shape mismatches on shared keys still raise from
    ``load_state_dict``.

    Args:
        args: Config object; only ``args.pretrained`` (a path accepted by
            ``torch.load``) is read.
    """
    # The identity map_location keeps every tensor on the CPU while loading.
    checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:  # supports checkpoints from train.py
        self.logging.info("Using ema ckpt!")
        checkpoint = checkpoint["ema"]

    model_dict = self.model.state_dict()

    # 1. Keep only keys the current model defines; log everything else.
    pretrained_dict = {}
    for key, value in checkpoint.items():
        if key in model_dict:
            pretrained_dict[key] = value
        else:
            self.logging.info("Ignoring: {}".format(key))

    # An empty checkpoint would otherwise raise ZeroDivisionError here.
    loaded_pct = len(pretrained_dict) / len(checkpoint) * 100 if checkpoint else 0.0
    self.logging.info(f"Successfully Load {loaded_pct}% original pretrained model weights ")

    # 2. Overwrite matching entries. The merged dict has exactly the model's key set,
    # so the (strict) load below cannot fail on missing/unexpected keys.
    model_dict.update(pretrained_dict)
    self.model.load_state_dict(model_dict)
    self.logging.info(f"Successfully load model at {args.pretrained}!")

    # NOTE: the trainer's global step is not restored from the checkpoint here.
| 79 | |
| 80 | def training_step(self, batch, batch_idx): |
| 81 | x = batch["video"].to(self.device) |
| 82 | video_name = batch["video_name"] |
| 83 | |
| 84 | with torch.no_grad(): |
| 85 | b, _, _, _, _ = x.shape |
| 86 | x = rearrange(x, "b f c h w -> (b f) c h w").contiguous() |
| 87 | x = self.vae.encode(x).latent_dist.sample().mul_(0.18215) |