r""" Forward pass through the diffusion model Args: x (List[Tensor]): List of input video tensors, each with shape [C_in, F, H, W] t (Tensor): Diffusion timesteps tensor of shape [B] context (List[Tensor]):
(
self,
x,
t,
context,
seq_len,
clip_fea=None,
y=None,
)
| 491 | self.init_weights() |
| 492 | |
| 493 | def forward( |
| 494 | self, |
| 495 | x, |
| 496 | t, |
| 497 | context, |
| 498 | seq_len, |
| 499 | clip_fea=None, |
| 500 | y=None, |
| 501 | ): |
| 502 | r""" |
| 503 | Forward pass through the diffusion model |
| 504 | |
| 505 | Args: |
| 506 | x (List[Tensor]): |
| 507 | List of input video tensors, each with shape [C_in, F, H, W] |
| 508 | t (Tensor): |
| 509 | Diffusion timesteps tensor of shape [B] |
| 510 | context (List[Tensor]): |
| 511 | List of text embeddings each with shape [L, C] |
| 512 | seq_len (`int`): |
| 513 | Maximum sequence length for positional encoding |
| 514 | clip_fea (Tensor, *optional*): |
| 515 | CLIP image features for image-to-video mode or first-last-frame-to-video mode |
| 516 | y (List[Tensor], *optional*): |
| 517 | Conditional video inputs for image-to-video mode, same shape as x |
| 518 | |
| 519 | Returns: |
| 520 | List[Tensor]: |
| 521 | List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] |
| 522 | """ |
| 523 | if self.model_type == 'i2v' or self.model_type == 'flf2v': |
| 524 | assert clip_fea is not None and y is not None |
| 525 | # params |
| 526 | device = self.patch_embedding.weight.device |
| 527 | if self.freqs.device != device: |
| 528 | self.freqs = self.freqs.to(device) |
| 529 | |
| 530 | if y is not None: |
| 531 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] |
| 532 | |
| 533 | # embeddings |
| 534 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] |
| 535 | grid_sizes = torch.stack( |
| 536 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) |
| 537 | x = [u.flatten(2).transpose(1, 2) for u in x] |
| 538 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long) |
| 539 | assert seq_lens.max() <= seq_len |
| 540 | x = torch.cat([ |
| 541 | torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], |
| 542 | dim=1) for u in x |
| 543 | ]) |
| 544 | |
| 545 | # time embeddings |
| 546 | with amp.autocast(dtype=torch.float32): |
| 547 | e = self.time_embedding( |
| 548 | sinusoidal_embedding_1d(self.freq_dim, t).float()) |
| 549 | e0 = self.time_projection(e).unflatten(1, (6, self.dim)) |
| 550 | assert e.dtype == torch.float32 and e0.dtype == torch.float32 |
nothing calls this directly
no test coverage detected