MCPcopy
hub / github.com/Wan-Video/Wan2.2 / forward

Method forward

wan/modules/model.py:126–155  ·  view source on GitHub ↗

Args:
- x (Tensor): Shape [B, L, num_heads, C / num_heads]
- seq_lens (Tensor): Shape [B]
- grid_sizes (Tensor): Shape [B, 3]; the second dimension contains (F, H, W)
- freqs (Tensor): RoPE freqs, shape [1024, C / num_heads / 2]

(self, x, seq_lens, grid_sizes, freqs)

Source from the content-addressed store, hash-verified

124 self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
125
def forward(self, x, seq_lens, grid_sizes, freqs):
    r"""
    Project ``x`` into per-head q/k/v, pass q and k through ``rope_apply``,
    run ``flash_attention`` and apply the output projection ``self.o``.

    Args:
        x(Tensor): Input whose first two dims are read as (B, L).
            NOTE(review): the upstream docstring gives
            [B, L, num_heads, C / num_heads], but ``x`` is fed directly to
            ``self.q`` / ``self.k`` / ``self.v`` and only afterwards viewed
            as [B, L, num_heads, head_dim] -- presumably [B, L, C]; confirm.
        seq_lens(Tensor): Shape [B]; forwarded to ``flash_attention`` as
            ``k_lens``
        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]

    Returns:
        The result of ``self.o`` applied to the attention output after its
        dims from index 2 onward have been flattened together.
    """
    batch, seq = x.shape[:2]
    heads, head_dim = self.num_heads, self.head_dim

    # query, key, value function
    def qkv_fn(inp):
        # every projection is reshaped to the same per-head layout
        head_shape = (batch, seq, heads, head_dim)
        query = self.norm_q(self.q(inp)).view(*head_shape)
        key = self.norm_k(self.k(inp)).view(*head_shape)
        value = self.v(inp).view(*head_shape)
        return query, key, value

    query, key, value = qkv_fn(x)

    # rope_apply is used on q and k only; v goes to attention unchanged
    rotated_q = rope_apply(query, grid_sizes, freqs)
    rotated_k = rope_apply(key, grid_sizes, freqs)
    attended = flash_attention(
        q=rotated_q,
        k=rotated_k,
        v=value,
        k_lens=seq_lens,
        window_size=self.window_size,
    )

    # output: merge dims from index 2 onward, then project
    return self.o(attended.flatten(2))
156
157
158class WanCrossAttention(WanSelfAttention):

Callers

nothing calls this directly

Calls 3

flash_attention — Function · 0.85
rope_apply — Function · 0.70
qkv_fn — Function · 0.50

Tested by

no test coverage detected