MCPcopy
hub / github.com/Vchitect/Latte / __init__

Method __init__

train_with_img_pl.py:31–55  ·  view source on GitHub ↗
(self, args, logger: logging.Logger)

Source from the content-addressed store, hash-verified

29
30class LatteTrainingModule(LightningModule):
def __init__(self, args, logger: logging.Logger):
    """Set up the Latte network, its EMA copy, the diffusion process, the VAE and the optimizer.

    Args:
        args: Training configuration object. This method reads
            ``args.pretrained`` (a checkpoint path handed to
            ``_load_pretrained_parameters``, or a falsy value to skip
            loading) and ``args.pretrained_model_path`` (the directory whose
            ``vae`` subfolder holds the autoencoder weights). The whole object
            is also forwarded to ``get_models``.
        logger: Logger for progress messages; stored as ``self.logging``.
    """
    super().__init__()
    self.args = args
    self.logging = logger

    # Trainable network plus an EMA shadow copy; requires_grad(..., False)
    # is called on the copy (presumably disabling gradient tracking for it).
    self.model = get_models(args)
    self.ema = deepcopy(self.model)
    requires_grad(self.ema, False)

    # Load pretrained model if specified
    if args.pretrained:
        # Load old checkpoint, only load EMA
        self._load_pretrained_parameters(args)
    num_params = sum(p.numel() for p in self.model.parameters())
    self.logging.info(f"Model Parameters: {num_params:,}")

    self.diffusion = create_diffusion(timestep_respacing="")
    self.vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
    # Only self.model's parameters are optimised; the EMA copy and the VAE are not.
    # NOTE(review): lr and weight_decay are hard-coded rather than read from
    # args — confirm this is intended.
    self.opt = torch.optim.AdamW(self.model.parameters(), lr=1e-4, weight_decay=0)
    self.lr_scheduler = None

    # Freeze VAE
    self.vae.requires_grad_(False)

    # The EMA copy was taken before any pretrained weights were loaded, so
    # re-sync it here with decay=0 (per the original comment this initialises
    # the EMA with the model's current weights).
    update_ema(self.ema, self.model, decay=0)  # Ensure EMA is initialized with synced weights
    self.model.train()  # important! This enables embedding dropout for classifier-free guidance
    # NOTE(review): depending on the Lightning version, the Trainer may call
    # .train() on the whole module when fitting starts, which would undo this
    # eval() — confirm for the Lightning version in use.
    self.ema.eval()
56
57 def _load_pretrained_parameters(self, args):
58 checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage)

Callers

nothing calls this directly

Calls 5

get_models · Function · 0.90
requires_grad · Function · 0.90
create_diffusion · Function · 0.90
update_ema · Function · 0.90

Tested by

no test coverage detected