(
self,
batch: Dict,
N: int = 8,
ucg_keys: List[str] = None,
only_log_video_latents=False,
**kwargs,
)
| 280 | |
| 281 | @torch.no_grad() |
| 282 | def log_video( |
| 283 | self, |
| 284 | batch: Dict, |
| 285 | N: int = 8, |
| 286 | ucg_keys: List[str] = None, |
| 287 | only_log_video_latents=False, |
| 288 | **kwargs, |
| 289 | ) -> Dict: |
| 290 | conditioner_input_keys = [e.input_key for e in self.conditioner.embedders] |
| 291 | if ucg_keys: |
| 292 | assert all(map(lambda x: x in conditioner_input_keys, ucg_keys)), ( |
| 293 | "Each defined ucg key for sampling must be in the provided conditioner input keys," |
| 294 | f"but we have {ucg_keys} vs. {conditioner_input_keys}" |
| 295 | ) |
| 296 | else: |
| 297 | ucg_keys = conditioner_input_keys |
| 298 | log = dict() |
| 299 | |
| 300 | x = self.get_input(batch) |
| 301 | |
| 302 | c, uc = self.conditioner.get_unconditional_conditioning( |
| 303 | batch, |
| 304 | force_uc_zero_embeddings=ucg_keys if len(self.conditioner.embedders) > 0 else [], |
| 305 | ) |
| 306 | |
| 307 | sampling_kwargs = {} |
| 308 | |
| 309 | N = min(x.shape[0], N) |
| 310 | x = x.to(self.device)[:N] |
| 311 | if not self.latent_input: |
| 312 | log["inputs"] = x.to(torch.float32) |
| 313 | x = x.permute(0, 2, 1, 3, 4).contiguous() |
| 314 | z = self.encode_first_stage(x, batch) |
| 315 | if not only_log_video_latents: |
| 316 | log["reconstructions"] = self.decode_first_stage(z).to(torch.float32) |
| 317 | log["reconstructions"] = log["reconstructions"].permute(0, 2, 1, 3, 4).contiguous() |
| 318 | z = z.permute(0, 2, 1, 3, 4).contiguous() |
| 319 | |
| 320 | log.update(self.log_conditionings(batch, N)) |
| 321 | |
| 322 | for k in c: |
| 323 | if isinstance(c[k], torch.Tensor): |
| 324 | c[k], uc[k] = map(lambda y: y[k][:N].to(self.device), (c, uc)) |
| 325 | |
| 326 | if self.noised_image_input: |
| 327 | image = x[:, :, 0:1] |
| 328 | image = self.add_noise_to_first_frame(image) |
| 329 | image = self.encode_first_stage(image, batch) |
| 330 | image = image.permute(0, 2, 1, 3, 4).contiguous() |
| 331 | image = torch.concat([image, torch.zeros_like(z[:, 1:])], dim=1) |
| 332 | c["concat"] = image |
| 333 | uc["concat"] = image |
| 334 | samples = self.sample(c, shape=z.shape[1:], uc=uc, batch_size=N, **sampling_kwargs) # b t c h w |
| 335 | samples = samples.permute(0, 2, 1, 3, 4).contiguous() |
| 336 | if only_log_video_latents: |
| 337 | latents = 1.0 / self.scale_factor * samples |
| 338 | log["latents"] = latents |
| 339 | else: |
no test coverage detected