MCPcopy
hub / github.com/Wan-Video/Wan2.2 / forward

Method forward

wan/modules/animate/face_blocks.py:334–383  ·  view source on GitHub ↗
(
        self,
        x: torch.Tensor,
        motion_vec: torch.Tensor,
        motion_mask: Optional[torch.Tensor] = None,
        use_context_parallel=False,
    )

Source from the content-addressed store, hash-verified

332 self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
333
def forward(
    self,
    x: torch.Tensor,
    motion_vec: torch.Tensor,
    motion_mask: Optional[torch.Tensor] = None,
    use_context_parallel=False,
) -> torch.Tensor:
    """Cross-attend token features ``x`` to per-frame motion tokens.

    Queries are projected from ``x``; keys and values are projected from
    ``motion_vec``. The (full) query sequence is split into ``T`` equal
    consecutive groups, one per frame of ``motion_vec``. Each group attends
    only to the ``N`` motion tokens of its own frame, because frames are
    folded into the attention batch dimension.

    Args:
        x: 3-D query features (batch, seq, features). After ``linear1_q``
            the last dim is split into ``self.heads_num`` heads. The total
            query length (after gathering, when context parallelism is on)
            must be divisible by ``T``.
        motion_vec: 4-D motion tokens, unpacked as ``(B, T, N, C)``.
        motion_mask: Optional 4-D mask. It is flattened over its last three
            dims and multiplied into the output as a per-token gate.
        use_context_parallel: If True, query shards are gathered along the
            sequence dim before attention. The result is then re-split so
            that this rank keeps only its own chunk.

    Returns:
        The ``linear2`` projection of the attention output, optionally
        gated by ``motion_mask``.

    Raises:
        ValueError: If ``motion_vec`` is not 4-D (from the shape unpacking).
    """
    # Unpacking (rather than indexing) also enforces a 4-D motion_vec.
    _, num_frames, _, _ = motion_vec.shape

    # Normalize both input streams before projecting them.
    x_motion = self.pre_norm_motion(motion_vec)
    x_feat = self.pre_norm_feat(x)

    kv = self.linear1_kv(x_motion)
    q = self.linear1_q(x_feat)

    # kv packs keys and values along its last dim as (2 * heads * head_dim).
    k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
    q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)

    # Apply QK-Norm if needed, then match v's dtype/device.
    q = self.q_norm(q).to(v)
    k = self.k_norm(k).to(v)

    # Fold frames into the batch dim: each frame becomes its own attention batch.
    k = rearrange(k, "B L N H D -> (B L) N H D")
    v = rearrange(v, "B L N H D -> (B L) N H D")

    if use_context_parallel:
        # Every frame group needs all of its queries, so collect the sequence
        # shards first. gather_forward presumably all-gathers along dim 1;
        # TODO confirm.
        q = gather_forward(q, dim=1)

    # Split the full query sequence into num_frames consecutive groups.
    q = rearrange(q, "B (L S) H D -> (B L) S H D", L=num_frames)

    attn = attention(
        q,
        k,
        v,
        max_seqlen_q=q.shape[1],
        batch_size=q.shape[0],
    )

    # attention is assumed to return (B*T, S, heads*head_dim); TODO confirm.
    attn = rearrange(attn, "(B L) S C -> B (L S) C", L=num_frames)
    if use_context_parallel:
        # Keep only this rank's slice of the sequence again.
        attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()]

    output = self.linear2(attn)

    if motion_mask is not None:
        # NOTE(review): the mask is flattened to its full T*H*W length, but
        # under context parallelism `output` holds only this rank's chunk.
        # Confirm that callers pass a matching (pre-sharded) mask in that case.
        output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)

    return output

Callers

nothing calls this directly

Calls 5

gather_forward — Function · 0.85
get_world_size — Function · 0.85
get_rank — Function · 0.85
to — Method · 0.80
attention — Function · 0.70

Tested by

no test coverage detected