MCP · copy
hub / github.com/MeiGen-AI/InfiniteTalk / forward

Method forward

wan/modules/multitalk_model.py:189–213  ·  view source on GitHub ↗
(self, x, context, context_lens)

Source from the content-addressed store, hash-verified

187 self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
188
def forward(self, x, context, context_lens):
    """Attend ``x`` to an image-token prefix and a text suffix of ``context``.

    ``context`` is split along dim 1: the first 257 tokens feed the image
    key/value projections (``k_img`` / ``v_img``), the rest feed the text
    key/value projections (``k`` / ``v``). One shared query is attended
    against each part separately. The two results are flattened over their
    head dimensions, summed, and passed through the output projection.

    Args:
        x: Query input; projected by ``self.q``. Assumed to be
            (batch, seq_len, dim) — TODO confirm against callers.
        context: Concatenated image+text context tokens, sliced on dim 1.
        context_lens: Key lengths forwarded to ``flash_attention`` for the
            text branch only. The image branch always uses ``k_lens=None``.
            When ``USE_SAGEATTN`` is truthy this argument is not used.

    Returns:
        ``self.o`` applied to the sum of the text and image attention outputs.
    """
    num_img_tokens = 257  # length of the image-token prefix inside ``context``
    img_ctx = context[:, :num_img_tokens]
    txt_ctx = context[:, num_img_tokens:]

    batch = x.shape[0]
    heads, head_dim = self.num_heads, self.head_dim

    def _split_heads(t):
        # Reshape a projection to (batch, -1, heads, head_dim) for attention.
        return t.view(batch, -1, heads, head_dim)

    # Shared query, then text and image key/value projections.
    query = _split_heads(self.norm_q(self.q(x)))
    txt_key = _split_heads(self.norm_k(self.k(txt_ctx)))
    txt_val = _split_heads(self.v(txt_ctx))
    img_key = _split_heads(self.norm_k_img(self.k_img(img_ctx)))
    img_val = _split_heads(self.v_img(img_ctx))

    if not USE_SAGEATTN:
        img_out = flash_attention(query, img_key, img_val, k_lens=None)
        txt_out = flash_attention(query, txt_key, txt_val, k_lens=context_lens)
    else:
        # NOTE(review): this path never uses ``context_lens``, so padded text
        # keys (if any) are attended — confirm that is intended.
        img_out = sageattn(query, img_key, img_val, tensor_layout='NHD')
        txt_out = sageattn(query, txt_key, txt_val, tensor_layout='NHD')

    # Merge the head dimensions, add the two branches, project the output.
    return self.o(txt_out.flatten(2) + img_out.flatten(2))
214
215
216class WanAttentionBlock(nn.Module):

Callers

nothing calls this directly

Calls 1

flash_attention (function) · 0.85

Tested by

no test coverage detected