MCPcopy
hub / github.com/Robbyant/lingbot-world / forward

Method forward

wan/modules/model_fast.py:172–213  ·  view source on GitHub ↗

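r""" Args: x(Tensor): Shape [B, L1, C] context(Tensor): Shape [B, L2, C] context_lens(Tensor): Shape [B] cross_attn_first_call(bool, optional): If provided, used as the "first call this generation" gate instead of reading crossattn_cache["is_init"].item() (which forces a CPU↔GPU sync). Caller (pipeline) tracks this as a Python bool. """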

(self, x, context, context_lens, crossattn_cache=None,
                cross_attn_first_call=None)

Source from the content-addressed store, hash-verified

class WanCrossAttention(WanSelfAttention):

    def forward(self, x, context, context_lens, crossattn_cache=None,
                cross_attn_first_call=None):
        r"""Cross-attend ``x`` to ``context``, optionally reusing cached K/V.

        Args:
            x(Tensor): Shape [B, L1, C]
            context(Tensor): Shape [B, L2, C]
            context_lens(Tensor): Shape [B]
            crossattn_cache(dict, optional): Mapping with the entries
                "is_init", "k" and "v". When given, the key/value
                projections of ``context`` are copied into "k"/"v" on the
                first call and read back from there on later calls.
            cross_attn_first_call(bool, optional): When not None, decides
                whether this is the first call of the generation, so that
                crossattn_cache["is_init"].item() (which forces a CPU↔GPU
                sync) does not have to be read. The caller (pipeline) is
                expected to track it as a plain Python bool.

        Returns:
            The result of ``self.o`` applied to the attention output with
            its dimensions from index 2 onward flattened together.
        """
        batch = x.size(0)
        heads = self.num_heads
        head_dim = self.head_dim

        # Query projection, split into heads.
        query = self.norm_q(self.q(x)).view(batch, -1, heads, head_dim)

        # Decide whether K/V must be projected from `context` on this call.
        has_cache = crossattn_cache is not None
        if not has_cache:
            project_kv = True
        elif cross_attn_first_call is not None:
            project_kv = cross_attn_first_call
        else:
            # Fallback gate: reads the flag tensor back to the host.
            project_kv = crossattn_cache["is_init"].item() == 0

        if project_kv:
            if has_cache:
                # Mark the cache as populated before filling it.
                crossattn_cache["is_init"].fill_(1)
            key = self.norm_k(self.k(context)).view(
                batch, -1, heads, head_dim)
            value = self.v(context).view(batch, -1, heads, head_dim)
            if has_cache:
                crossattn_cache["k"].copy_(key)
                crossattn_cache["v"].copy_(value)
        else:
            # Later calls reuse the projections stored on the first call.
            key = crossattn_cache["k"]
            value = crossattn_cache["v"]

        # Attention over the context, then merge heads and project out.
        attended = flash_attention(query, key, value, k_lens=context_lens)
        return self.o(attended.flatten(2))
214
215
216class CausalWanAttentionBlock(nn.Module):

Callers

nothing calls this directly

Calls 2

flash_attentionFunction · 0.85
sizeMethod · 0.80

Tested by

no test coverage detected