MCPcopy
hub / github.com/Robbyant/lingbot-world / forward

Method forward

wan/modules/model.py:436–547  ·  view source on GitHub ↗

Forward pass through the diffusion model. Args: x (List[Tensor]): list of input video tensors, each with shape [C_in, F, H, W]; t (Tensor): diffusion timesteps tensor of shape [B]; context (List[Tensor]): list of text embeddings, each with shape [L, C]; …

(
        self,
        x,
        t,
        context,
        seq_len,
        y=None,
        dit_cond_dict=None,
    )

Source retrieved from the content-addressed store (hash-verified)

434 self.init_weights()
435
436 def forward(
437 self,
438 x,
439 t,
440 context,
441 seq_len,
442 y=None,
443 dit_cond_dict=None,
444 ):
445 r"""
446 Forward pass through the diffusion model
447
448 Args:
449 x (List[Tensor]):
450 List of input video tensors, each with shape [C_in, F, H, W]
451 t (Tensor):
452 Diffusion timesteps tensor of shape [B]
453 context (List[Tensor]):
454 List of text embeddings each with shape [L, C]
455 seq_len (`int`):
456 Maximum sequence length for positional encoding
457 y (List[Tensor], *optional*):
458 Conditional video inputs for image-to-video mode, same shape as x
459
460 Returns:
461 List[Tensor]:
462 List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
463 """
464 if self.model_type == 'i2v':
465 assert y is not None
466 # params
467 device = self.patch_embedding.weight.device
468 if self.freqs.device != device:
469 self.freqs = self.freqs.to(device)
470
471 if y is not None:
472 x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
473
474 # embeddings
475 x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
476 grid_sizes = torch.stack(
477 [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
478 x = [u.flatten(2).transpose(1, 2) for u in x]
479 seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
480 assert seq_lens.max() <= seq_len
481 x = torch.cat([
482 torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
483 dim=1) for u in x
484 ])
485
486 # time embeddings
487 if t.dim() == 1:
488 t = t.expand(t.size(0), seq_len)
489 with torch.amp.autocast('cuda', dtype=torch.float32):
490 bt = t.size(0)
491 t = t.flatten()
492 e = self.time_embedding(
493 sinusoidal_embedding_1d(self.freq_dim,

Callers

nothing calls this directly

Calls 4

unpatchify — Method · 0.95
to — Method · 0.80
size — Method · 0.80
sinusoidal_embedding_1d — Function · 0.70

Tested by

no test coverage detected