(
self,
x: torch.Tensor,
motion_vec: torch.Tensor,
motion_mask: Optional[torch.Tensor] = None,
use_context_parallel=False,
)
| 332 | self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs) |
| 333 | |
| 334 | def forward( |
| 335 | self, |
| 336 | x: torch.Tensor, |
| 337 | motion_vec: torch.Tensor, |
| 338 | motion_mask: Optional[torch.Tensor] = None, |
| 339 | use_context_parallel=False, |
| 340 | ) -> torch.Tensor: |
| 341 | |
| 342 | B, T, N, C = motion_vec.shape |
| 343 | T_comp = T |
| 344 | |
| 345 | x_motion = self.pre_norm_motion(motion_vec) |
| 346 | x_feat = self.pre_norm_feat(x) |
| 347 | |
| 348 | kv = self.linear1_kv(x_motion) |
| 349 | q = self.linear1_q(x_feat) |
| 350 | |
| 351 | k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num) |
| 352 | q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num) |
| 353 | |
| 354 | # Apply QK-Norm if needed. |
| 355 | q = self.q_norm(q).to(v) |
| 356 | k = self.k_norm(k).to(v) |
| 357 | |
| 358 | k = rearrange(k, "B L N H D -> (B L) N H D") |
| 359 | v = rearrange(v, "B L N H D -> (B L) N H D") |
| 360 | |
| 361 | if use_context_parallel: |
| 362 | q = gather_forward(q, dim=1) |
| 363 | |
| 364 | q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp) |
| 365 | # Compute attention. |
| 366 | attn = attention( |
| 367 | q, |
| 368 | k, |
| 369 | v, |
| 370 | max_seqlen_q=q.shape[1], |
| 371 | batch_size=q.shape[0], |
| 372 | ) |
| 373 | |
| 374 | attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp) |
| 375 | if use_context_parallel: |
| 376 | attn = torch.chunk(attn, get_world_size(), dim=1)[get_rank()] |
| 377 | |
| 378 | output = self.linear2(attn) |
| 379 | |
| 380 | if motion_mask is not None: |
| 381 | output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1) |
| 382 | |
| 383 | return output |
No direct callers of this `forward` method were found.
No test coverage was detected.