(self, x, seq_lens, grid_sizes, freqs)
| 160 | self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity() |
| 161 | |
def forward(self, x, seq_lens, grid_sizes, freqs):
    """Compute self-attention over ``x`` and apply the output projection.

    Args:
        x: Input tensor. Its first two dimensions are read as
            (batch, sequence length); presumably shaped [B, L, C] --
            TODO confirm against the caller.
        seq_lens: Forwarded unchanged to ``flash_attention`` as ``k_lens``.
        grid_sizes: Forwarded to ``rope_apply`` for both queries and keys.
        freqs: Forwarded to ``rope_apply`` for both queries and keys.

    Returns:
        The result of ``self.o`` applied to the attention output after
        every dimension from index 2 onward has been flattened into one.
    """
    batch, length = x.shape[:2]
    head_shape = (batch, length, self.num_heads, self.head_dim)

    # Project the input three times. Only queries and keys go through
    # their norm modules; values are used as projected.
    query = self.norm_q(self.q(x)).view(*head_shape)
    key = self.norm_k(self.k(x)).view(*head_shape)
    value = self.v(x).view(*head_shape)

    # rope_apply (presumably rotary position embedding -- confirm) is
    # applied to queries and keys only.
    rotated_query = rope_apply(query, grid_sizes, freqs)
    rotated_key = rope_apply(key, grid_sizes, freqs)

    attended = flash_attention(
        q=rotated_query,
        k=rotated_key,
        v=value,
        k_lens=seq_lens,
        window_size=self.window_size)

    # Merge the trailing (head) dimensions before the output projection.
    return self.o(attended.flatten(2))
| 185 | |
| 186 | |
| 187 | class SwinSelfAttention(SelfAttention): |
nothing calls this directly
no test coverage detected