MCPcopy
hub / github.com/MeiGen-AI/InfiniteTalk / forward

Method forward

wan/modules/model.py:493–582  ·  view source on GitHub ↗

Forward pass through the diffusion model. Args: x (List[Tensor]): List of input video tensors, each with shape [C_in, F, H, W]; t (Tensor): Diffusion timesteps tensor of shape [B]; context (List[Tensor]): List of text embeddings, each with shape [L, C]; …

(
        self,
        x,
        t,
        context,
        seq_len,
        clip_fea=None,
        y=None,
    )

Source from the content-addressed store, hash-verified

491 self.init_weights()
492
493 def forward(
494 self,
495 x,
496 t,
497 context,
498 seq_len,
499 clip_fea=None,
500 y=None,
501 ):
502 r"""
503 Forward pass through the diffusion model
504
505 Args:
506 x (List[Tensor]):
507 List of input video tensors, each with shape [C_in, F, H, W]
508 t (Tensor):
509 Diffusion timesteps tensor of shape [B]
510 context (List[Tensor]):
511 List of text embeddings each with shape [L, C]
512 seq_len (`int`):
513 Maximum sequence length for positional encoding
514 clip_fea (Tensor, *optional*):
515 CLIP image features for image-to-video mode or first-last-frame-to-video mode
516 y (List[Tensor], *optional*):
517 Conditional video inputs for image-to-video mode, same shape as x
518
519 Returns:
520 List[Tensor]:
521 List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
522 """
523 if self.model_type == 'i2v' or self.model_type == 'flf2v':
524 assert clip_fea is not None and y is not None
525 # params
526 device = self.patch_embedding.weight.device
527 if self.freqs.device != device:
528 self.freqs = self.freqs.to(device)
529
530 if y is not None:
531 x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
532
533 # embeddings
534 x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
535 grid_sizes = torch.stack(
536 [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
537 x = [u.flatten(2).transpose(1, 2) for u in x]
538 seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
539 assert seq_lens.max() <= seq_len
540 x = torch.cat([
541 torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
542 dim=1) for u in x
543 ])
544
545 # time embeddings
546 with amp.autocast(dtype=torch.float32):
547 e = self.time_embedding(
548 sinusoidal_embedding_1d(self.freq_dim, t).float())
549 e0 = self.time_projection(e).unflatten(1, (6, self.dim))
550 assert e.dtype == torch.float32 and e0.dtype == torch.float32

Callers

nothing calls this directly

Calls 2

unpatchify — Method · 0.95
sinusoidal_embedding_1d — Function · 0.70

Tested by

no test coverage detected