Repository: github.com/Robbyant/lingbot-world

Method forward

wan/modules/s2v/model_s2v.py:151–180

Args:
    x (Tensor): Shape [B, L, num_heads, C / num_heads]
    seq_lens (Tensor): Shape [B]
    grid_sizes (Tensor): Shape [B, 3]; the second dimension contains (F, H, W)
    freqs (Tensor): RoPE freqs, shape [1024, C / num_heads / 2]

Signature: forward(self, x, seq_lens, grid_sizes, freqs)
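The argument shapes can be cross-checked against how the method body uses `x`. Below is a minimal bookkeeping sketch, assuming `self.q`, `self.k`, `self.v` and `self.o` are C → C linear layers inherited from `WanSelfAttention` (their definitions are not shown on this page). Because `forward` reads `b, s` from `x.shape[:2]` and projects `x` before calling `.view(b, s, n, d)`, the tensor it consumes appears to be `[B, L, C]`; the per-head layout `[B, L, num_heads, C / num_heads]` is what `q`, `k` and `v` have after the view. The `freqs` entry assumes one complex rotation per pair of head channels, which is what a last dimension of `C / num_heads / 2` suggests. `AttnShapes` and `valid_tokens` are hypothetical names, not part of the repository.

```python
from dataclasses import dataclass


@dataclass
class AttnShapes:
    """Hypothetical helper: shapes seen inside WanS2VSelfAttention.forward."""
    batch: int      # B
    seq_len: int    # L, padded token count
    dim: int        # C, model width (assumed in/out width of self.q/k/v/o)
    num_heads: int  # n

    @property
    def head_dim(self) -> int:
        # d = C / num_heads; the heads must split C exactly.
        if self.dim % self.num_heads:
            raise ValueError("dim must be divisible by num_heads")
        return self.dim // self.num_heads

    def shapes(self) -> dict:
        d = self.head_dim
        return {
            "x": (self.batch, self.seq_len, self.dim),               # input to self.q / self.k / self.v
            "q/k/v": (self.batch, self.seq_len, self.num_heads, d),  # after .view(b, s, n, d)
            "freqs": (1024, d // 2),                                 # one rotation per channel pair
            "out": (self.batch, self.seq_len, self.dim),             # after flatten(2) and self.o
        }


def valid_tokens(grid_sizes):
    """Tokens spanned by each sample's (F, H, W) grid: F * H * W per row."""
    return [f * h * w for f, h, w in grid_sizes]


shapes = AttnShapes(batch=2, seq_len=64, dim=512, num_heads=8).shapes()
print(shapes["q/k/v"], shapes["freqs"])      # (2, 64, 8, 64) (1024, 32)
print(valid_tokens([(4, 4, 4), (2, 4, 4)]))  # [64, 32]
```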

Source (from the content-addressed store, hash-verified):

```python
class WanS2VSelfAttention(WanSelfAttention):

    def forward(self, x, seq_lens, grid_sizes, freqs):
        """
        Args:
            x(Tensor): Shape [B, L, num_heads, C / num_heads]
            seq_lens(Tensor): Shape [B]
            grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
            freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        # query, key, value function
        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)

        x = flash_attention(
            q=rope_apply(q, grid_sizes, freqs),
            k=rope_apply(k, grid_sizes, freqs),
            v=v,
            k_lens=seq_lens,
            window_size=self.window_size)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x
```
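To make the data flow of `forward` concrete, here is a NumPy reference of the same sequence of steps: project, normalise q and k, split into heads, rotate q and k, attend with keys masked by `seq_lens`, merge heads, project out. It is a sketch under stated assumptions, not the repository's implementation: the `params` dictionary and the helper names are hypothetical, biases are omitted, `norm_q`/`norm_k` are approximated by a scale-free RMS norm over the full width C, `flash_attention(..., k_lens=seq_lens)` is replaced by an explicit softmax in which keys at positions `>= seq_lens[i]` receive zero weight, and `window_size` is treated as unrestricted (global) attention. RoPE is passed in as a callable that defaults to the identity, mirroring the fact that `rope_apply` touches `q` and `k` but not `v`.

```python
import numpy as np


def rms_norm(t, eps=1e-6):
    # Stand-in for self.norm_q / self.norm_k: RMS-normalise over the full width C
    # (assumption: no learned scale, to keep the sketch short).
    return t / np.sqrt(np.mean(t * t, axis=-1, keepdims=True) + eps)


def masked_attention(q, k, v, k_lens):
    """Assumed semantics of flash_attention(q, k, v, k_lens=...) without a window.

    q, k, v: [B, L, n, d]. For sample i, keys at positions >= k_lens[i] get zero weight.
    """
    _, s, _, d = q.shape
    scores = np.einsum("bqnd,bknd->bnqk", q, k) / np.sqrt(d)    # [B, n, L, L]
    keep = np.arange(s)[None, :] < np.asarray(k_lens)[:, None]   # [B, L], over keys
    scores = np.where(keep[:, None, None, :], scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)         # stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.einsum("bnqk,bknd->bqnd", weights, v)              # [B, L, n, d]


def s2v_self_attention(x, seq_lens, params, num_heads, rope=lambda t: t):
    """Same step order as forward; x: [B, L, C], params: hypothetical C x C matrices."""
    b, s, c = x.shape
    n, d = num_heads, c // num_heads
    q = rms_norm(x @ params["wq"]).reshape(b, s, n, d)   # qkv_fn: project, norm, split heads
    k = rms_norm(x @ params["wk"]).reshape(b, s, n, d)
    v = (x @ params["wv"]).reshape(b, s, n, d)           # v is neither normalised nor rotated
    out = masked_attention(rope(q), rope(k), v, seq_lens)
    out = out.reshape(b, s, c)                           # x.flatten(2): merge heads
    return out @ params["wo"]                            # self.o(x)


rng = np.random.default_rng(0)
C, heads = 8, 2
params = {name: rng.standard_normal((C, C)) for name in ("wq", "wk", "wv", "wo")}
x = rng.standard_normal((2, 6, C))   # B=2, L=6, C=8
seq_lens = [6, 4]                     # sample 1 has two padded positions
y = s2v_self_attention(x, seq_lens, params, heads)
print(y.shape)                        # (2, 6, 8)
```

Because masked keys receive exactly zero weight and every other step acts per token, editing the padded tail of a sample leaves that sample's first `seq_lens[i]` outputs unchanged; only the padded positions' own outputs move.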

Callers: nothing calls this directly

Calls (3):
    flash_attention · function · 0.85
    qkv_fn · function · 0.70
    rope_apply · function · 0.70
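The `freqs` shape in the docstring, `[1024, C / num_heads / 2]`, fits a RoPE table holding one complex rotation per pair of head channels for up to 1024 positions, and `grid_sizes` supplies the (F, H, W) extent that `rope_apply` needs. The NumPy sketch below shows one way such a table can be split across the F, H and W axes and applied to `q` or `k`. It is modelled on the 3-D RoPE of the base Wan model as an assumption: the `rope_apply` called here lives in the S2V module and may treat `grid_sizes` differently, so the channel split, `theta`, the choice to leave tokens beyond F·H·W unrotated, and the helper names `rope_params_1d`, `rope_table_3d` and `rope_apply_3d` are all illustrative and hypothetical.

```python
import numpy as np


def rope_params_1d(max_pos, dim, theta=10000.0):
    """Complex table exp(i * pos * theta**(-2j / dim)), shape [max_pos, dim // 2]."""
    assert dim % 2 == 0
    inv_freq = 1.0 / theta ** (np.arange(0, dim, 2, dtype=np.float64) / dim)
    return np.exp(1j * np.outer(np.arange(max_pos), inv_freq))


def rope_table_3d(head_dim, max_pos=1024):
    """Hypothetical freqs builder: [max_pos, head_dim // 2], channels split across F, H, W."""
    d = head_dim
    return np.concatenate([
        rope_params_1d(max_pos, d - 4 * (d // 6)),  # F axis: the remainder, never fewer than H or W
        rope_params_1d(max_pos, 2 * (d // 6)),      # H axis
        rope_params_1d(max_pos, 2 * (d // 6)),      # W axis
    ], axis=1)


def rope_apply_3d(x, grid_sizes, freqs):
    """x: [B, L, n, d] real; grid_sizes: B rows of (F, H, W); freqs: [max_pos, d // 2] complex."""
    _, _, n, d = x.shape
    c = d // 2
    sizes = [c - 2 * (c // 3), c // 3, c // 3]
    f_tab, h_tab, w_tab = np.split(freqs, np.cumsum(sizes)[:2], axis=1)
    out = x.astype(np.float64)                    # astype copies, so x itself is untouched
    for i, (f, h, w) in enumerate(grid_sizes):
        seq_len = f * h * w                       # tokens past seq_len are left as they are
        # Token (t, y, z) gets its F-, H- and W-rotations concatenated along channels.
        rot = np.concatenate([
            np.broadcast_to(f_tab[:f, None, None, :], (f, h, w, sizes[0])),
            np.broadcast_to(h_tab[None, :h, None, :], (f, h, w, sizes[1])),
            np.broadcast_to(w_tab[None, None, :w, :], (f, h, w, sizes[2])),
        ], axis=-1).reshape(seq_len, 1, c)
        # Treat channel pairs (2j, 2j + 1) as complex numbers and rotate them.
        pairs = out[i, :seq_len].reshape(seq_len, n, c, 2)
        z = (pairs[..., 0] + 1j * pairs[..., 1]) * rot
        out[i, :seq_len] = np.stack([z.real, z.imag], axis=-1).reshape(seq_len, n, d)
    return out


rng = np.random.default_rng(0)
x = rng.standard_normal((1, 10, 2, 12))   # B=1, L=10, n=2 heads, d=12
freqs = rope_table_3d(12)                  # shape (1024, 6)
y = rope_apply_3d(x, [(2, 2, 2)], freqs)  # 8 grid tokens, 2 trailing tokens untouched
```

A rotation changes only the phase of each channel pair, so pair lengths are preserved, and the token at grid position (0, 0, 0) is unchanged because every angle there is zero.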

Tested by: no test coverage detected