(self, x, seq_lens, grid_sizes, freqs)
class SwinSelfAttention(SelfAttention):
    """Self-attention over video frames with a local temporal window.

    The flat token sequence is split into ``T`` frames of ``H * W`` tokens
    (``(T, H, W)`` is read from ``grid_sizes[0]``); the last frame is treated
    as a reference frame. Every non-reference frame ``i`` attends to the
    tokens of frame ``i``, frame ``i + 1``, frame ``i - 1`` (at the edges the
    frame itself stands in for the missing neighbour) and the reference
    frame. The reference frame issues no queries; its output slot is filled
    with its own value projection.
    """

    @staticmethod
    def _gather_window(frames):
        """Concatenate each non-reference frame with its window along dim 1.

        ``frames`` is indexed as ``(frame, token, head, head_dim)`` with the
        reference frame last. For ``F = frames.shape[0] - 1`` non-reference
        frames the result has ``F`` rows, each the concatenation of
        [current frame, next frame, previous frame, reference frame], so the
        token dimension grows to ``4 * frames.shape[1]``.
        """
        local = frames[:-1]
        # Replicate the first/last local frame so both edges have a
        # "previous" and a "next" neighbour.
        padded = torch.cat([local[:1], local, local[-1:]])
        ref = repeat(frames[-1:], "1 s n d -> t s n d", t=local.shape[0])
        # Order (current, next, previous, reference) is kept exactly as in
        # the original code; it may matter if self.window_size restricts
        # attention to a local band -- NOTE(review): confirm.
        return torch.cat([padded[1:-1], padded[2:], padded[:-2], ref], dim=1)

    def forward(self, x, seq_lens, grid_sizes, freqs):
        """Run windowed temporal self-attention.

        Args:
            x: Input tokens; ``x.shape[:2]`` is read as ``(batch, seq_len)``.
                Only batch size 1 is supported.
            seq_lens: Accepted for interface compatibility; not used here.
            grid_sizes: Per-sample ``(T, H, W)`` grid. Only ``grid_sizes[0]``
                is read, and ``seq_len`` must equal ``T * H * W`` for the
                frame split to succeed.
            freqs: RoPE frequencies forwarded to ``rope_apply``.

        Returns:
            ``self.o`` applied to the attention output with the head
            dimensions flattened back into the channel dimension.
        """
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
        assert b == 1, 'Only support batch_size 1'

        # Project and split into heads: each is (b, s, n, d).
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)

        # RoPE is applied over the full (T, H, W) grid, reference frame included.
        q = rope_apply(q, grid_sizes, freqs)
        k = rope_apply(k, grid_sizes, freqs)
        T, H, W = grid_sizes[0].tolist()

        # Split the flat sequence into frames: (b t) (h w) n d.
        to_frames = 'b (t h w) n d -> (b t) (h w) n d'
        q = rearrange(q, to_frames, t=T, h=H, w=W)
        k = rearrange(k, to_frames, t=T, h=H, w=W)
        v = rearrange(v, to_frames, t=T, h=H, w=W)

        # The reference frame does not attend; keep its values to pass through.
        v_ref = v[-1:]

        q = q[:-1]                   # (T-1) (hw) n d
        k = self._gather_window(k)   # (T-1) (4hw) n d
        v = self._gather_window(v)   # (T-1) (4hw) n d

        out = flash_attention(q=q, k=k, v=v, window_size=self.window_size)

        # Re-append the reference frame and restore the flat token layout.
        out = torch.cat([out, v_ref], dim=0)
        out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)

        # Merge heads and apply the output projection.
        return self.o(out.flatten(2))
| 240 | |
| 241 | |
| 242 | # Fix the reference frame's RoPE grid to (1, H, W). |
nothing calls this directly
no test coverage detected