MCPcopy
hub / github.com/zai-org/CogVideo / log_video

Method log_video

sat/diffusion_video.py:282–353  ·  view source on GitHub ↗
(
        self,
        batch: Dict,
        N: int = 8,
        ucg_keys: List[str] = None,
        only_log_video_latents=False,
        **kwargs,
    )

Source from the content-addressed store, hash-verified

280
281 @torch.no_grad()
282 def log_video(
283 self,
284 batch: Dict,
285 N: int = 8,
286 ucg_keys: List[str] = None,
287 only_log_video_latents=False,
288 **kwargs,
289 ) -> Dict:
290 conditioner_input_keys = [e.input_key for e in self.conditioner.embedders]
291 if ucg_keys:
292 assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), (
293 "Each defined ucg key for sampling must be in the provided conditioner input keys,"
294 f"but we have {ucg_keys} vs. {conditioner_input_keys}"
295 )
296 else:
297 ucg_keys = conditioner_input_keys
298 log = dict()
299
300 x = self.get_input(batch)
301
302 c, uc = self.conditioner.get_unconditional_conditioning(
303 batch,
304 force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [],
305 )
306
307 sampling_kwargs = {}
308
309 N = min(x.shape[0], N)
310 x = x.to(self.device)[:N]
311 if not self.latent_input:
312 log["inputs"] = x.to(torch.float32)
313 x = x.permute(0, 2, 1, 3, 4).contiguous()
314 z = self.encode_first_stage(x, batch)
315 if not only_log_video_latents:
316 log["reconstructions"] = self.decode_first_stage(z).to(torch.float32)
317 log["reconstructions"] = log["reconstructions"].permute(0, 2, 1, 3, 4).contiguous()
318 z = z.permute(0, 2, 1, 3, 4).contiguous()
319
320 log.update(self.log_conditionings(batch, N))
321
322 for k in c:
323 if isinstance(c[k], torch.Tensor):
324 c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc))
325
326 if self.noised_image_input:
327 image = x[:, :, 0:1]
328 image = self.add_noise_to_first_frame(image)
329 image = self.encode_first_stage(image, batch)
330 image = image.permute(0, 2, 1, 3, 4).contiguous()
331 image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1)
332 c["concat"] = image
333 uc["concat"] = image
334 samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w
335 samples = samples.permute(0, 2, 1, 3, 4).contiguous()
336 if only_log_video_latents:
337 latents = 1.0 / self.scale_factor * samples
338 log["latents"] = latents
339 else:

Callers 1

log_video · Function · 0.80

Calls 8

get_input · Method · 0.95
encode_first_stage · Method · 0.95
decode_first_stage · Method · 0.95
log_conditionings · Method · 0.95
sample · Method · 0.95
update · Method · 0.45

Tested by

no test coverage detected