r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] context_lens(Tensor): Shape [B]
(self, x, context, context_lens)
| 200 | self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() |
| 201 | |
def forward(self, x, context, context_lens):
    r"""
    Attend from ``x`` to two slices of ``context`` and sum the results.

    ``context`` is split along dim 1 at
    ``context.shape[1] - T5_CONTEXT_TOKEN_NUMBER``: positions before the
    split feed the ``k_img`` / ``v_img`` projections, the remaining
    positions feed the ``k`` / ``v`` projections. Both attention calls
    share the same query, and their flattened outputs are added before
    the output projection ``self.o``.

    Args:
        x(Tensor): Shape [B, L1, C]
        context(Tensor): Shape [B, L2, C]
        context_lens(Tensor): Shape [B]; passed as ``k_lens`` only to the
            attention over the trailing (non-image) slice.

    Returns:
        The result of ``self.o`` applied to the summed attention outputs
        (presumably Shape [B, L1, C] -- TODO confirm against ``self.o``).
    """
    batch = x.shape[0]
    heads, head_dim = self.num_heads, self.head_dim

    def to_heads(t):
        # [B, L, C] -> [B, L, num_heads, head_dim]
        return t.view(batch, -1, heads, head_dim)

    # split the context into its leading (image) and trailing parts
    split_at = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER
    img_ctx = context[:, :split_at]
    txt_ctx = context[:, split_at:]

    # query / key / value projections for the trailing context slice
    query = to_heads(self.norm_q(self.q(x)))
    key = to_heads(self.norm_k(self.k(txt_ctx)))
    value = to_heads(self.v(txt_ctx))

    # key / value projections for the leading (image) context slice
    key_img = to_heads(self.norm_k_img(self.k_img(img_ctx)))
    value_img = to_heads(self.v_img(img_ctx))

    # image attention runs without key lengths; the other uses context_lens
    img_out = flash_attention(query, key_img, value_img, k_lens=None)
    txt_out = flash_attention(query, key, value, k_lens=context_lens)

    # merge heads, sum both branches, then project
    merged = txt_out.flatten(2) + img_out.flatten(2)
    return self.o(merged)
| 230 | |
| 231 | |
| 232 | WAN_CROSSATTENTION_CLASSES = { |
nothing calls this directly
no test coverage detected