MCP (copy)
hub / github.com/Robbyant/lingbot-world / forward

Method forward

wan/modules/s2v/motioner.py:189–239  ·  view source on GitHub ↗
(self, x, seq_lens, grid_sizes, freqs)

Source from the content-addressed store, hash-verified

class SwinSelfAttention(SelfAttention):
    """Self-attention with a local temporal window over a frame grid.

    Tokens are grouped per frame using ``grid_sizes[0] = (T, H, W)``. The last
    frame is treated as a reference frame (per the ``ref_*`` handling below):
    every other frame attends to the keys/values of its own frame, the next
    frame, the previous frame (the first/last non-reference frame stands in
    for a missing neighbour at the edges) and the reference frame. No
    attention is computed for the reference frame itself; its output slot is
    filled with its own value vectors.
    """

    def forward(self, x, seq_lens, grid_sizes, freqs):
        """Run windowed self-attention on a single sample.

        Args:
            x: Input tokens; ``x.shape[:2]`` is read as ``(batch, seq_len)``
                and ``seq_len`` must equal ``T * H * W`` for the rearranges
                below to succeed. Presumably ``(1, T*H*W, dim)`` -- TODO
                confirm against callers.
            seq_lens: Not used by this implementation; no key-length masking
                is passed to ``flash_attention``.
            grid_sizes: Only ``grid_sizes[0]`` is read here, as ``(T, H, W)``;
                the whole value is also forwarded to ``rope_apply``.
            freqs: RoPE frequencies forwarded to ``rope_apply``.

        Returns:
            ``self.o`` applied to the attention result with heads merged, i.e.
            an input of shape ``(batch, seq_len, num_heads * head_dim)``.

        Raises:
            ValueError: If the batch size is not 1.
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
        # Explicit check instead of ``assert`` so it survives ``python -O``;
        # with b > 1 the (b t) frame axis below would mix samples together.
        if b != 1:
            raise ValueError('Only support batch_size 1')

        # Per-head projections: (b, s, n, d).
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)

        q = rope_apply(q, grid_sizes, freqs)
        k = rope_apply(k, grid_sizes, freqs)
        T, H, W = grid_sizes[0].tolist()

        # One row per frame: (b, T*H*W, n, d) -> (b*T, H*W, n, d).
        to_frames = 'b (t h w) n d -> (b t) (h w) n d'
        q = rearrange(q, to_frames, t=T, h=H, w=W)
        k = rearrange(k, to_frames, t=T, h=H, w=W)
        v = rearrange(v, to_frames, t=T, h=H, w=W)

        # The last frame is the reference frame: keep its values for the
        # output and query only with the remaining T - 1 frames.
        ref_v = v[-1:]
        q = q[:-1]                      # (T-1, H*W, n, d)
        k = self._temporal_window(k)    # (T-1, 4*H*W, n, d)
        v = self._temporal_window(v)    # (T-1, 4*H*W, n, d)

        # window_size semantics are defined by flash_attention.
        out = flash_attention(q=q, k=k, v=v, window_size=self.window_size)

        # Re-append the reference frame (passed through as its values) and
        # restore the token layout: (T, H*W, n, d) -> (b, T*H*W, n, d).
        out = torch.cat([out, ref_v], dim=0)
        out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)

        # Merge heads, (b, T*H*W, n, d) -> (b, T*H*W, n*d), then project.
        return self.o(out.flatten(2))

    @staticmethod
    def _temporal_window(frames):
        """Build the per-frame key/value context for the non-reference frames.

        Args:
            frames: Per-frame tensor ``(T, H*W, n, d)`` whose last entry is the
                reference frame.

        Returns:
            Tensor ``(T-1, 4*H*W, n, d)``: for each non-reference frame ``i``,
            the concatenation along the token axis of frame ``i``, frame
            ``i+1``, frame ``i-1`` and the reference frame. At the edges the
            first/last non-reference frame is duplicated in place of the
            missing neighbour.
        """
        ref = repeat(frames[-1:], '1 s n d -> t s n d', t=frames.shape[0] - 1)
        body = frames[:-1]
        # Edge-pad along the frame axis so that i-1 and i+1 exist for every i.
        padded = torch.cat([body[:1], body, body[-1:]])
        current, following, preceding = padded[1:-1], padded[2:], padded[:-2]
        return torch.cat([current, following, preceding, ref], dim=1)
240
241
242#Fix the reference frame RoPE to 1,H,W.

Callers

Nothing calls this directly.

Calls 3

flash_attention — Function · 0.85
qkv_fn — Function · 0.70
rope_apply — Function · 0.70

Tested by

No test coverage detected.