(self, args, logger: logging.Logger)
| 29 | |
| 30 | class LatteTrainingModule(LightningModule): |
def __init__(self, args, logger: logging.Logger):
    """Build the trainable network, its EMA copy, and the fixed training parts.

    Args:
        args: Run configuration. This constructor reads ``args.pretrained``
            (checkpoint path; falsy skips loading) and
            ``args.pretrained_model_path`` (location with a ``vae``
            subfolder), and forwards ``args`` itself to ``get_models``.
        logger: Logger kept on ``self.logging`` and used for status messages.
    """
    super().__init__()
    self.args = args
    self.logging = logger

    # The EMA copy is cloned from the freshly built network and has
    # gradients switched off.
    network = get_models(args)
    self.model = network
    self.ema = deepcopy(network)
    requires_grad(self.ema, False)

    # Optional warm start; must happen before the parameter count is logged
    # and before the EMA sync at the end of this constructor.
    if args.pretrained:
        # Original note: load old checkpoint, only load EMA.
        self._load_pretrained_parameters(args)
    param_total = sum(weight.numel() for weight in self.model.parameters())
    self.logging.info(f"Model Parameters: {param_total:,}")

    self.diffusion = create_diffusion(timestep_respacing="")

    # The VAE is frozen: gradients are disabled on it.
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
    vae.requires_grad_(False)
    self.vae = vae

    # Only the trainable network's parameters are handed to the optimizer.
    self.opt = torch.optim.AdamW(self.model.parameters(), lr=1e-4, weight_decay=0)
    self.lr_scheduler = None

    update_ema(self.ema, self.model, decay=0)  # Ensure EMA is initialized with synced weights
    self.model.train()  # important! This enables embedding dropout for classifier-free guidance
    self.ema.eval()
| 56 | |
| 57 | def _load_pretrained_parameters(self, args): |
| 58 | checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage) |
nothing calls this directly
no test coverage detected