MCPcopy
hub / github.com/Vchitect/Latte / training_step

Method training_step

train_with_img_pl.py:80–123  ·  view source on GitHub ↗
(self, batch, batch_idx)

Source from the content-addressed store, hash-verified

78 # self.global_step = int(args.pretrained.split("/")[-1].split(".")[0]) # dirty implementation
79
def training_step(self, batch, batch_idx):
    """Compute the diffusion training loss for one batch of videos.

    Encodes ``batch["video"]`` frame-by-frame with the VAE (no gradients),
    samples one random diffusion timestep per clip, and returns the mean of
    ``self.diffusion.training_losses(...)["loss"]``.

    Gradient clipping and gradient-norm logging are intentionally NOT done
    here. With Lightning's automatic optimization, ``training_step`` runs
    before ``zero_grad`` and ``backward``. Any gradients visible at this point
    therefore belong to the previous, already-applied optimizer step, and
    clipping them would have no effect on this batch's update. That work is
    done in ``on_before_optimizer_step`` below, after backward has produced
    this batch's gradients.

    Args:
        batch: dict providing at least ``"video"`` (rank-5 tensor laid out as
            ``b f c h w`` per the ``rearrange`` pattern) and ``"video_name"``
            (passed to the model as ``y``). When
            ``self.args.dataset == "ucf101_img"`` it must also provide
            ``"image_name"``: strings of integers joined by ``"====="``.
        batch_idx: index of the batch; unused.

    Returns:
        The scalar (mean) training loss, which Lightning backpropagates.

    Raises:
        ValueError: if ``self.args.extras == 78`` (text-to-video training is
            not supported).
    """
    x = batch["video"].to(self.device)
    video_name = batch["video_name"]

    if self.args.dataset == "ucf101_img":
        # One tensor of integer ids per sample, parsed from '====='-joined strings.
        # NOTE(review): these tensors are created on CPU; confirm the model
        # moves them to self.device before using them.
        image_names = [
            torch.as_tensor([int(item) for item in caption.split('=====')])
            for caption in batch['image_name']
        ]

    with torch.no_grad():
        # Fold frames into the batch axis so the image VAE encodes each frame.
        b, _, _, _, _ = x.shape
        x = rearrange(x, "b f c h w -> (b f) c h w").contiguous()
        # 0.18215: presumably the Stable Diffusion VAE latent scale factor.
        x = self.vae.encode(x).latent_dist.sample().mul_(0.18215)
        x = rearrange(x, "(b f) c h w -> b f c h w", b=b).contiguous()

    if self.args.extras == 78:  # text-to-video
        raise ValueError("T2V training is not supported at this moment!")
    elif self.args.extras == 2:
        # video_name is passed as the conditioning label y.
        if self.args.dataset == "ucf101_img":
            model_kwargs = dict(y=video_name, y_image=image_names, use_image_num=self.args.use_image_num)
        else:
            model_kwargs = dict(y=video_name)
    else:
        model_kwargs = dict(y=None, use_image_num=self.args.use_image_num)

    t = torch.randint(0, self.diffusion.num_timesteps, (x.shape[0],), device=self.device)
    loss_dict = self.diffusion.training_losses(self.model, x, t, model_kwargs)
    loss = loss_dict["loss"].mean()

    self.log("train_loss", loss)
    # Kept for the periodic console log emitted from on_before_optimizer_step.
    self._last_train_loss = loss.detach()
    return loss


def on_before_optimizer_step(self, optimizer, *args):
    """Clip gradients and log their norm once this step's backward has run.

    Lightning calls this hook after the backward pass and before
    ``optimizer.step()``, so the gradients seen here are the ones about to be
    applied. ``*args`` absorbs the extra ``optimizer_idx`` argument that older
    Lightning versions pass.

    Before ``self.args.start_clip_iter`` the norm is only measured
    (``clip_grad=False``); from then on it is also clipped to
    ``self.args.clip_max_norm``. Every ``self.args.log_every`` steps a summary
    line with the latest training loss is written to ``self.logging``.
    """
    gradient_norm = clip_grad_norm_(
        self.model.parameters(),
        self.args.clip_max_norm,
        clip_grad=self.global_step >= self.args.start_clip_iter,
    )
    self.log("gradient_norm", gradient_norm)

    if (self.global_step+1) % self.args.log_every == 0:
        loss = self._last_train_loss
        self.logging.info(
            f"(step={self.global_step+1:07d}/epoch={self.current_epoch:04d}) Train Loss: {loss:.4f}, Gradient Norm: {gradient_norm:.4f}"
        )
124
125 def on_train_batch_end(self, *args, **kwargs):
126 update_ema(self.ema, self.model)

Callers

nothing calls this directly

Calls 6

clip_grad_norm_ (Function) · 0.90
append (Method) · 0.80
sample (Method) · 0.80
mean (Method) · 0.80
encode (Method) · 0.45
training_losses (Method) · 0.45

Tested by

no test coverage detected