MCPcopy
hub / github.com/MeiGen-AI/InfiniteTalk / forward

Method forward

wan/modules/model.py:202–229  ·  view source on GitHub ↗

Args: x (Tensor): shape [B, L1, C]; context (Tensor): shape [B, L2, C]; context_lens (Tensor): shape [B].

(self, x, context, context_lens)

Source from the content-addressed store, hash-verified

200 self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
201
def forward(self, x, context, context_lens):
    r"""
    Attend from ``x`` to two slices of ``context`` and sum the results.

    The leading slice of ``context`` (everything before the final
    ``T5_CONTEXT_TOKEN_NUMBER`` positions) goes through the ``*_img``
    key/value projections; the trailing slice goes through the regular
    key/value projections. Both attention outputs share the same query.

    Args:
        x(Tensor): Shape [B, L1, C]
        context(Tensor): Shape [B, L2, C]
        context_lens(Tensor): Shape [B]; forwarded as ``k_lens`` to the
            attention over the trailing slice only

    Returns:
        The result of ``self.o`` applied to the summed attention outputs.
    """
    # Everything before the last T5_CONTEXT_TOKEN_NUMBER positions is the
    # image slice; the rest is handled by the regular key/value branch.
    split_at = context.shape[1] - T5_CONTEXT_TOKEN_NUMBER
    img_tokens = context[:, :split_at]
    rest_tokens = context[:, split_at:]

    # Target layout for splitting the channel dim into (heads, head_dim).
    head_shape = (x.size(0), -1, self.num_heads, self.head_dim)

    # Query plus key/value for the trailing slice.
    query = self.norm_q(self.q(x)).view(*head_shape)
    key = self.norm_k(self.k(rest_tokens)).view(*head_shape)
    value = self.v(rest_tokens).view(*head_shape)

    # Key/value for the image slice; this attention runs without key lengths.
    key_img = self.norm_k_img(self.k_img(img_tokens)).view(*head_shape)
    value_img = self.v_img(img_tokens).view(*head_shape)
    attn_img = flash_attention(query, key_img, value_img, k_lens=None)

    # Attention over the trailing slice, bounded by context_lens.
    attn_rest = flash_attention(query, key, value, k_lens=context_lens)

    # Merge the head dims back into one channel dim, sum both branches,
    # then apply the output projection.
    merged = attn_rest.flatten(2) + attn_img.flatten(2)
    return self.o(merged)
230
231
232WAN_CROSSATTENTION_CLASSES = {

Callers

nothing calls this directly

Calls 1

flash_attention · Function · 0.85

Tested by

no test coverage detected