MCPcopy
hub / github.com/Vchitect/Latte / training_step

Method training_step

train_pl.py:80–113  ·  view source on GitHub ↗
(self, batch, batch_idx)

Source from the content-addressed store, hash-verified

78 # self.global_step = int(args.pretrained.split("/")[-1].split(".")[0]) # dirty implementation
79
def training_step(self, batch, batch_idx):
    """Compute the diffusion training loss for one batch of videos.

    The ``"video"`` tensor is flattened to individual frames, encoded with
    ``self.vae`` under ``torch.no_grad()``, and folded back into per-video
    latents. A random timestep is drawn for every sample and
    ``self.diffusion.training_losses`` produces the loss dictionary.

    Args:
        batch: Mapping with a ``"video"`` tensor laid out as
            ``(batch, frames, channels, height, width)`` and a
            ``"video_name"`` entry. ``"video_name"`` is always read, but it
            is only forwarded to the model (as ``y``) when
            ``self.args.extras == 2``.
        batch_idx: Index of the batch; not used by this method.

    Returns:
        The mean of the ``"loss"`` entry returned by
        ``self.diffusion.training_losses``.

    Raises:
        ValueError: If ``self.args.extras == 78`` (text-to-video). Note that
            this is raised only after the VAE encoding has run.
    """
    videos = batch["video"].to(self.device)
    names = batch["video_name"]

    with torch.no_grad():
        # The five-way unpack doubles as a check that the input is 5-D.
        batch_size, _, _, _, _ = videos.shape
        frames = rearrange(videos, "b f c h w -> (b f) c h w").contiguous()
        # 0.18215: presumably the VAE latent scaling factor — TODO confirm.
        latents = self.vae.encode(frames).latent_dist.sample().mul_(0.18215)
        latents = rearrange(
            latents, "(b f) c h w -> b f c h w", b=batch_size
        ).contiguous()

    extras = self.args.extras
    if extras == 78:  # text-to-video
        raise ValueError("T2V training is not supported at this moment!")
    # extras == 2 conditions the model on the batch's "video_name" entry;
    # every other value trains without a conditioning input.
    model_kwargs = {"y": names if extras == 2 else None}

    timesteps = torch.randint(
        0,
        self.diffusion.num_timesteps,
        (latents.shape[0],),
        device=self.device,
    )
    losses = self.diffusion.training_losses(
        self.model, latents, timesteps, model_kwargs
    )
    loss = losses["loss"].mean()

    # Before `start_clip_iter` the helper is called with clip_grad=False
    # (presumably: measure the norm only — confirm against the helper);
    # from that step onwards it is called with clip_grad=True.
    # NOTE(review): nothing in this method has called backward on `loss` by
    # this point, so the norm is computed on whatever gradients the
    # parameters already hold — presumably left over from an earlier step.
    # Confirm this is intended.
    clipping_enabled = not (self.global_step < self.args.start_clip_iter)
    gradient_norm = clip_grad_norm_(
        self.model.parameters(),
        self.args.clip_max_norm,
        clip_grad=clipping_enabled,
    )

    self.log("train_loss", loss)
    self.log("gradient_norm", gradient_norm)

    # Emit a text log line every `log_every` steps (1-based step count).
    if (self.global_step + 1) % self.args.log_every == 0:
        self.logging.info(
            f"(step={self.global_step+1:07d}/epoch={self.current_epoch:04d}) Train Loss: {loss:.4f}, Gradient Norm: {gradient_norm:.4f}"
        )
    return loss
114
def on_train_batch_end(self, *args, **kwargs):
    """Pass ``self.ema`` and ``self.model`` to ``update_ema``.

    Any positional or keyword arguments supplied by the caller are accepted
    but not used here.
    """
    # Presumably invoked by the trainer after each training batch so the EMA
    # copy tracks the live model — confirm against the trainer's hook order.
    ema_model, live_model = self.ema, self.model
    update_ema(ema_model, live_model)

Callers

nothing calls this directly

Calls 5

clip_grad_norm_ — Function · 0.90
sample — Method · 0.80
mean — Method · 0.80
encode — Method · 0.45
training_losses — Method · 0.45

Tested by

no test coverage detected