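r""" Sequence-parallel causal self-attention using Ulysses all-to-all. Input x is sequence-chunked (each GPU holds padded_seq_lens/world_size tokens). all-to-all gathers the full sequence and scatters heads, so each GPU operates on the full sequence with local_heads. After RoPE, padding tokens are sliced off so that only the actual ``seq_lens`` tokens enter the KV cache and flash attention. The output is padded back to ``padded_seq_lens`` before the reverse all-to-all. """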
(
self,
x,
seq_lens,
grid_sizes,
freqs,
kv_cache=None,
current_start=0,
max_attention_size=1_000_000,
frame_seqlen=None,
seq_lens_int=None)
| 412 | |
| 413 | |
| 414 | def sp_attn_forward_causal( |
| 415 | self, |
| 416 | x, |
| 417 | seq_lens, |
| 418 | grid_sizes, |
| 419 | freqs, |
| 420 | kv_cache=None, |
| 421 | current_start=0, |
| 422 | max_attention_size=1_000_000, |
| 423 | frame_seqlen=None, |
| 424 | seq_lens_int=None): |
| 425 | r""" |
| 426 | Sequence-parallel causal self-attention using Ulysses all-to-all. |
| 427 | |
| 428 | Input x is sequence-chunked (each GPU holds padded_seq_lens/world_size tokens). |
| 429 | all-to-all gathers the full sequence and scatters heads, so each GPU |
| 430 | operates on the full sequence with local_heads. After RoPE, padding |
| 431 | tokens are sliced off so that only the actual ``seq_lens`` tokens enter |
| 432 | the KV cache and flash attention. The output is padded back to |
| 433 | ``padded_seq_lens`` before the reverse all-to-all. |
| 434 | |
| 435 | Args: |
| 436 | x(Tensor): Shape [B, padded_seq_lens // world_size, C] |
| 437 | seq_lens(Tensor): Total number of valid tokens (full sequence, not per-rank). |
| 438 | grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W). |
| 439 | freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]. |
| 440 | kv_cache(dict): Self-attention KV cache. Contains keys ``k``, ``v`` |
| 441 | (Tensor of shape [B, kv_size, local_heads, head_dim]), |
| 442 | ``global_end_index``, and ``local_end_index`` |
| 443 | (scalar Tensors tracking cache position). |
| 444 | current_start(int): Token offset of the current chunk in the full sequence. |
| 445 | Used to index into the KV cache and compute positional |
| 446 | embeddings correctly. |
| 447 | max_attention_size(int): Maximum number of KV tokens each query can attend to. |
| 448 | Limits the effective context window of self-attention |
| 449 | to control memory usage. |
| 450 | """ |
| 451 | b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim |
| 452 | sp_size = get_world_size() |
| 453 | |
| 454 | # query, key, value function |
| 455 | def qkv_fn(x): |
| 456 | q = self.norm_q(self.q(x)).view(b, s, n, d) |
| 457 | k = self.norm_k(self.k(x)).view(b, s, n, d) |
| 458 | v = self.v(x).view(b, s, n, d) |
| 459 | return q, k, v |
| 460 | |
| 461 | q, k, v = qkv_fn(x) |
| 462 | |
| 463 | # all-to-all: gather sequence, scatter heads |
| 464 | # [B, s/p, N, d] -> [B, s, N/p, d] |
| 465 | q = all_to_all(q, scatter_dim=2, gather_dim=1) |
| 466 | k = all_to_all(k, scatter_dim=2, gather_dim=1) |
| 467 | v = all_to_all(v, scatter_dim=2, gather_dim=1) |
| 468 | |
| 469 | # padded_seq_lens = s * sp_size may exceed seq_lens due to SP padding |
| 470 | padded_seq_lens = s * sp_size |
| 471 | if seq_lens_int is None: |
nothing calls this directly
no test coverage detected