MCPcopy
hub / github.com/Vchitect/Latte / LatteTrainingModule

Class LatteTrainingModule

train_pl.py:30–136  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

28
29
30class LatteTrainingModule(LightningModule):
def __init__(self, args, logger: logging.Logger):
    """Assemble the networks and optimiser used for Latte training.

    Builds the trainable model and a gradient-free EMA copy, optionally
    loads pretrained weights, creates the diffusion process, loads a frozen
    VAE, and sets up an AdamW optimiser over the model's parameters.

    Args:
        args: Training configuration. Read directly here:
            ``args.pretrained`` (checkpoint path; falsy to skip loading) and
            ``args.pretrained_model_path`` (location whose ``vae`` subfolder
            holds the autoencoder). The whole object is also passed to
            ``get_models``.
        logger: Logger for progress messages, kept as ``self.logging``.
    """
    super(LatteTrainingModule, self).__init__()
    self.args = args
    # Kept under the name ``logging``, presumably so it does not clash with
    # LightningModule's built-in ``logger`` property (TODO confirm).
    self.logging = logger

    # Trainable network plus a frozen copy that tracks its EMA weights.
    self.model = get_models(args)
    self.ema = deepcopy(self.model)
    requires_grad(self.ema, False)

    if args.pretrained:
        # Only self.model receives the checkpoint; the EMA copy is re-synced
        # from it at the end of __init__.
        self._load_pretrained_parameters(args)
    param_count = sum(p.numel() for p in self.model.parameters())
    self.logging.info(f"Model Parameters: {param_count:,}")

    self.diffusion = create_diffusion(timestep_respacing="")

    # The VAE only encodes inputs (under no_grad in training_step) and is
    # never trained.
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
    self.vae = vae
    self.vae.requires_grad_(False)

    # The optimiser covers only the trainable model; no LR scheduler is set here.
    self.opt = torch.optim.AdamW(self.model.parameters(), lr=1e-4, weight_decay=0)
    self.lr_scheduler = None

    # The EMA copy was taken before any pretrained weights were loaded, so
    # sync it with the current model weights (decay=0) before training.
    update_ema(self.ema, self.model, decay=0)
    # Training mode matters: it enables the embedding dropout used for
    # classifier-free guidance. The EMA copy stays in eval mode.
    self.model.train()
    self.ema.train(False)
56
def _load_pretrained_parameters(self, args):
    """Load weights from ``args.pretrained`` into ``self.model``.

    If the loaded checkpoint contains an ``"ema"`` entry (the format produced
    by train.py), only that sub-dict is used. Checkpoint keys that do not
    exist in ``self.model.state_dict()`` are logged and skipped. Every
    matching key overwrites the model's current value.

    Args:
        args: Config object; only ``args.pretrained`` (a path accepted by
            ``torch.load``) is read here.
    """
    # NOTE(review): torch.load unpickles the file, so only point
    # args.pretrained at trusted checkpoints.
    # The storage-identity map_location loads every tensor onto the CPU.
    checkpoint = torch.load(args.pretrained, map_location=lambda storage, loc: storage)
    if "ema" in checkpoint:  # supports checkpoints from train.py
        self.logging.info("Using ema ckpt!")
        checkpoint = checkpoint["ema"]

    model_dict = self.model.state_dict()
    # 1. Keep only the keys the current model knows about; log the rest.
    pretrained_dict = {}
    for k, v in checkpoint.items():
        if k in model_dict:
            pretrained_dict[k] = v
        else:
            self.logging.info("Ignoring: {}".format(k))

    # Guard the ratio: an empty checkpoint used to raise ZeroDivisionError here.
    total = len(checkpoint)
    loaded_pct = len(pretrained_dict) / total * 100 if total else 0.0
    self.logging.info(f"Successfully Load {loaded_pct}% original pretrained model weights ")

    # 2. Overwrite the matching entries and load the merged dict.
    # NOTE: load_state_dict is strict by default, so a kept key whose tensor
    # shape differs from the model's will still raise here.
    model_dict.update(pretrained_dict)
    self.model.load_state_dict(model_dict)
    self.logging.info(f"Successfully load model at {args.pretrained}!")
77
78 # self.global_step = int(args.pretrained.split("/")[-1].split(".")[0]) # dirty implementation
79
80 def training_step(self, batch, batch_idx):
81 x = batch["video"].to(self.device)
82 video_name = batch["video_name"]
83
84 with torch.no_grad():
85 b, _, _, _, _ = x.shape
86 x = rearrange(x, "b f c h w -> (b f) c h w").contiguous()
87 x = self.vae.encode(x).latent_dist.sample().mul_(0.18215)

Callers 1

mainFunction · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected