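r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] context_lens(Tensor): Shape [B] cross_attn_first_call(bool, optional): If provided, used as the "first call this generation" gate instead of reading crossattn_cache["is_init"].item() (which forces a CPU↔GPU sync). Caller (pipeline) tracks this as a Python bool. """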
(self, x, context, context_lens, crossattn_cache=None,
cross_attn_first_call=None)
class WanCrossAttention(WanSelfAttention):
    r"""Cross-attention from a query sequence ``x`` onto a ``context`` sequence.

    Uses the ``q``/``k``/``v``/``o`` projections, the ``norm_q``/``norm_k``
    layers and the ``num_heads``/``head_dim`` attributes found on ``self``
    (presumably set up by ``WanSelfAttention.__init__`` -- confirm there).
    """

    def forward(self, x, context, context_lens, crossattn_cache=None,
                cross_attn_first_call=None):
        r"""
        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
            crossattn_cache(dict, optional): Mapping with tensor entries
                "k", "v" and "is_init". When given, the projected context
                keys/values are copied into "k"/"v" on the first call and
                read back from there on later calls, so ``context`` is
                projected only once. When None, keys/values are projected
                on every call and nothing is cached.
            cross_attn_first_call(bool, optional): If provided, used as the
                "first call this generation" gate instead of reading
                crossattn_cache["is_init"].item() (which forces a CPU-GPU
                sync). Caller (pipeline) tracks this as a Python bool.
                Ignored when ``crossattn_cache`` is None.

        Returns:
            Tensor: output of the ``self.o`` projection applied to the
            attention result with heads flattened (presumably
            [B, L1, C] -- confirm against ``flash_attention``).
        """
        b, n, d = x.size(0), self.num_heads, self.head_dim

        # compute query
        q = self.norm_q(self.q(x)).view(b, -1, n, d)

        # Decide whether key/value must be projected from `context` now or
        # can be taken from the cache.
        if crossattn_cache is None:
            project_kv = True
        elif cross_attn_first_call is None:
            # Fallback gate: reads the flag tensor back to the host.
            project_kv = crossattn_cache["is_init"].item() == 0
        else:
            project_kv = cross_attn_first_call

        # compute key, value
        if project_kv:
            k = self.norm_k(self.k(context)).view(b, -1, n, d)
            v = self.v(context).view(b, -1, n, d)
            if crossattn_cache is not None:
                crossattn_cache["k"].copy_(k)
                crossattn_cache["v"].copy_(v)
                # Flag the cache as initialised only after both copies have
                # succeeded; if the projection or copy raises, the flag
                # stays unset and a later call re-projects instead of
                # attending over stale cache contents.
                crossattn_cache["is_init"].fill_(1)
        else:
            k = crossattn_cache["k"]
            v = crossattn_cache["v"]

        # compute attention
        x = flash_attention(q, k, v, k_lens=context_lens)

        # output
        x = x.flatten(2)
        x = self.o(x)
        return x
| 214 | |
| 215 | |
| 216 | class CausalWanAttentionBlock(nn.Module): |
nothing calls this directly
no test coverage detected