r""" Forward pass through the diffusion model Args: x (List[Tensor]): List of input video tensors, each with shape [C_in, F, H, W] t (Tensor): Diffusion timesteps tensor of shape [B] context (List[Tensor]):
(
self,
x,
t,
context,
seq_len,
y=None,
dit_cond_dict=None,
)
| 434 | self.init_weights() |
| 435 | |
| 436 | def forward( |
| 437 | self, |
| 438 | x, |
| 439 | t, |
| 440 | context, |
| 441 | seq_len, |
| 442 | y=None, |
| 443 | dit_cond_dict=None, |
| 444 | ): |
| 445 | r""" |
| 446 | Forward pass through the diffusion model |
| 447 | |
| 448 | Args: |
| 449 | x (List[Tensor]): |
| 450 | List of input video tensors, each with shape [C_in, F, H, W] |
| 451 | t (Tensor): |
| 452 | Diffusion timesteps tensor of shape [B] |
| 453 | context (List[Tensor]): |
| 454 | List of text embeddings each with shape [L, C] |
| 455 | seq_len (`int`): |
| 456 | Maximum sequence length for positional encoding |
| 457 | y (List[Tensor], *optional*): |
| 458 | Conditional video inputs for image-to-video mode, same shape as x |
| 459 | |
| 460 | Returns: |
| 461 | List[Tensor]: |
| 462 | List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] |
| 463 | """ |
| 464 | if self.model_type == 'i2v': |
| 465 | assert y is not None |
| 466 | # params |
| 467 | device = self.patch_embedding.weight.device |
| 468 | if self.freqs.device != device: |
| 469 | self.freqs = self.freqs.to(device) |
| 470 | |
| 471 | if y is not None: |
| 472 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] |
| 473 | |
| 474 | # embeddings |
| 475 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] |
| 476 | grid_sizes = torch.stack( |
| 477 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) |
| 478 | x = [u.flatten(2).transpose(1, 2) for u in x] |
| 479 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) |
| 480 | assert seq_lens.max() <= seq_len |
| 481 | x = torch.cat([ |
| 482 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], |
| 483 | dim=1) for u in x |
| 484 | ]) |
| 485 | |
| 486 | # time embeddings |
| 487 | if t.dim() == 1: |
| 488 | t = t.expand(t.size(0), seq_len) |
| 489 | with torch.amp.autocast('cuda', dtype=torch.float32): |
| 490 | bt = t.size(0) |
| 491 | t = t.flatten() |
| 492 | e = self.time_embedding( |
| 493 | sinusoidal_embedding_1d(self.freq_dim, |
Nothing calls this method directly.
No test coverage was detected.