hub / github.com/Robbyant/lingbot-world / sp_attn_forward_causal

Function sp_attn_forward_causal

wan/distributed/sequence_parallel.py:414–550

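Sequence-parallel causal self-attention using Ulysses all-to-all. Input x is sequence-chunked (each GPU holds padded_seq_lens/world_size tokens). The all-to-all gathers the full sequence and scatters heads, so each GPU operates on the full sequence with local_heads. After RoPE, padding tokens are sliced off so that only the actual seq_lens tokens enter the KV cache and flash attention. The output is padded back to padded_seq_lens before the reverse all-to-all.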

(
    self,
    x,
    seq_lens,
    grid_sizes,
    freqs,
    kv_cache=None,
    current_start=0,
    max_attention_size=1_000_000,
    frame_seqlen=None,
    seq_lens_int=None)
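
The docstring describes max_attention_size as the maximum number of KV tokens each query can attend to, and current_start as the token offset of the current chunk. A minimal sketch of how such a limit could translate into a slice of the cache follows. It is an assumption for illustration only: kv_window is a hypothetical helper, and the repository's actual windowing logic is not visible in the listing on this page.

```python
def kv_window(current_start, chunk_len, max_attention_size):
    """Return the [start, end) range of cached tokens a chunk may attend to.

    Hypothetical helper: assumes the cache is filled up to the end of the
    current chunk and that the window keeps the most recent tokens.
    """
    end = current_start + chunk_len              # cache position after this chunk
    start = max(0, end - max_attention_size)     # drop tokens older than the window
    return start, end


# Three consecutive chunks of 4 tokens with a 6-token attention window.
windows = [kv_window(start, 4, 6) for start in (0, 4, 8)]
print(windows)
```

With the default max_attention_size=1_000_000 the start index stays at 0 for any realistic sequence, so under this assumption the default would leave the context window effectively unbounded.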

Source from the content-addressed store, hash-verified

def sp_attn_forward_causal(
        self,
        x,
        seq_lens,
        grid_sizes,
        freqs,
        kv_cache=None,
        current_start=0,
        max_attention_size=1_000_000,
        frame_seqlen=None,
        seq_lens_int=None):
    r"""
    Sequence-parallel causal self-attention using Ulysses all-to-all.

    Input x is sequence-chunked (each GPU holds padded_seq_lens/world_size tokens).
    all-to-all gathers the full sequence and scatters heads, so each GPU
    operates on the full sequence with local_heads. After RoPE, padding
    tokens are sliced off so that only the actual ``seq_lens`` tokens enter
    the KV cache and flash attention. The output is padded back to
    ``padded_seq_lens`` before the reverse all-to-all.

    Args:
        x(Tensor): Shape [B, padded_seq_lens // world_size, C]
        seq_lens(Tensor): Total number of valid tokens (full sequence, not per-rank).
        grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W).
        freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2].
        kv_cache(dict): Self-attention KV cache. Contains keys ``k``, ``v``
            (Tensor of shape [B, kv_size, local_heads, head_dim]),
            ``global_end_index``, and ``local_end_index``
            (scalar Tensors tracking cache position).
        current_start(int): Token offset of the current chunk in the full sequence.
            Used to index into the KV cache and compute positional
            embeddings correctly.
        max_attention_size(int): Maximum number of KV tokens each query can attend to.
            Limits the effective context window of self-attention
            to control memory usage.
    """
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    sp_size = get_world_size()

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x)).view(b, s, n, d)
        k = self.norm_k(self.k(x)).view(b, s, n, d)
        v = self.v(x).view(b, s, n, d)
        return q, k, v

    q, k, v = qkv_fn(x)

    # all-to-all: gather sequence, scatter heads
    # [B, s/p, N, d] -> [B, s, N/p, d]
    q = all_to_all(q, scatter_dim=2, gather_dim=1)
    k = all_to_all(k, scatter_dim=2, gather_dim=1)
    v = all_to_all(v, scatter_dim=2, gather_dim=1)

    # padded_seq_lens = s * sp_size may exceed seq_lens due to SP padding
    padded_seq_lens = s * sp_size
    if seq_lens_int is None:
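
The layout change in the listing ([B, s/p, N, d] -> [B, s, N/p, d]) and the padding handling described in the docstring can be illustrated without any GPUs. The following is a single-process sketch, not the repository's all_to_all: every function name in it is hypothetical, the batch and head_dim axes are omitted, tokens are replaced by (token, head) labels, and rounding the sequence length up to a multiple of sp_size is an assumed padding rule.

```python
def make_sequence_chunks(padded_len, num_heads, sp_size):
    """Rank r holds tokens [r*s, (r+1)*s) with all heads: shape [s, N]."""
    s = padded_len // sp_size
    return [
        [[(t, h) for h in range(num_heads)] for t in range(r * s, (r + 1) * s)]
        for r in range(sp_size)
    ]


def gather_seq_scatter_heads(chunks):
    """Simulated forward all-to-all: [s/p, N] per rank -> [s, N/p] per rank."""
    sp_size = len(chunks)
    local_heads = len(chunks[0][0]) // sp_size
    out = []
    for rank in range(sp_size):
        h0 = rank * local_heads
        # Concatenate every rank's tokens in rank order, keeping this rank's heads.
        out.append([tok[h0:h0 + local_heads] for chunk in chunks for tok in chunk])
    return out


def gather_heads_scatter_seq(chunks):
    """Simulated reverse all-to-all: [s, N/p] per rank -> [s/p, N] per rank."""
    sp_size = len(chunks)
    s = len(chunks[0]) // sp_size
    out = []
    for rank in range(sp_size):
        rows = []
        for t in range(rank * s, (rank + 1) * s):
            # Rejoin the head slices that each rank holds for token t.
            rows.append([label for chunk in chunks for label in chunk[t]])
        out.append(rows)
    return out


sp_size, num_heads, seq_len = 2, 4, 5
# Assumed padding rule: round the sequence length up to a multiple of sp_size.
padded_len = -(-seq_len // sp_size) * sp_size

chunks = make_sequence_chunks(padded_len, num_heads, sp_size)
gathered = gather_seq_scatter_heads(chunks)

# Each rank now sees the full padded sequence but only its local heads.
valid = gathered[0][:seq_len]                    # padding tokens sliced off
pad_row = [None] * len(gathered[0][0])
repadded = valid + [pad_row] * (padded_len - seq_len)  # padded back before the reverse step

restored = gather_heads_scatter_seq(gathered)
print(len(gathered[0]), len(gathered[0][0]), len(valid), len(repadded))
```

Because the reverse step needs every rank to contribute the same number of tokens, the output has to return to the padded length before it is redistributed, which is why the trimmed sequence is padded again.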

Callers

nothing calls this directly

Calls 7

get_world_size · Function · 0.85
all_to_all · Function · 0.85
flash_attention · Function · 0.85
type_as · Method · 0.80
size · Method · 0.80
qkv_fn · Function · 0.70
causal_rope_apply · Function · 0.70

Tested by

no test coverage detected