(self, x, context, context_lens)
| 187 | self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() |
| 188 | |
def forward(self, x, context, context_lens):
    """Cross-attend ``x`` to two slices of ``context`` and sum the results.

    ``context`` is split along dim 1 at index 257. The leading slice is
    projected with ``k_img``/``v_img`` (presumably image-conditioning
    tokens, going by the ``_img`` naming -- confirm against the caller);
    the trailing slice is projected with ``k``/``v``. Both attention
    passes share the same queries, and their outputs are added before the
    output projection ``self.o``.

    Args:
        x: Query input; ``x.size(0)`` is taken as the batch size.
            Assumed to be laid out as (batch, seq, dim) -- TODO confirm.
        context: Tensor indexed as ``context[:, :257]`` and
            ``context[:, 257:]``.
        context_lens: Forwarded as ``k_lens`` to ``flash_attention`` for
            the trailing slice. Not consulted when ``USE_SAGEATTN`` is set.

    Returns:
        ``self.o`` applied to the sum of the two attention outputs, each
        flattened from dim 2 onward.
    """
    # Target shape for splitting the projection output into heads.
    head_shape = (x.size(0), -1, self.num_heads, self.head_dim)

    leading_ctx = context[:, :257]
    trailing_ctx = context[:, 257:]

    # Projections, reshaped to (batch, tokens, heads, head_dim).
    # Note that only queries and keys pass through a norm layer.
    query = self.norm_q(self.q(x)).view(*head_shape)
    key = self.norm_k(self.k(trailing_ctx)).view(*head_shape)
    value = self.v(trailing_ctx).view(*head_shape)
    key_img = self.norm_k_img(self.k_img(leading_ctx)).view(*head_shape)
    value_img = self.v_img(leading_ctx).view(*head_shape)

    # Attention over each slice with the shared queries.
    if not USE_SAGEATTN:
        img_out = flash_attention(query, key_img, value_img, k_lens=None)
        ctx_out = flash_attention(query, key, value, k_lens=context_lens)
    else:
        # NOTE(review): this path does not pass context_lens, so no
        # per-sample key length is applied here -- confirm this is intended.
        img_out = sageattn(query, key_img, value_img, tensor_layout='NHD')
        ctx_out = sageattn(query, key, value, tensor_layout='NHD')

    # Merge the head dimensions back, combine both passes, and project.
    merged = ctx_out.flatten(2) + img_out.flatten(2)
    return self.o(merged)
| 214 | |
| 215 | |
| 216 | class WanAttentionBlock(nn.Module): |
nothing calls this directly
no test coverage detected