Repository: github.com/Robbyant/lingbot-world

Function sp_dit_forward_causal

Source: wan/distributed/sequence_parallel.py, lines 249–411

x: A list of videos each with shape [C, T, H, W]. t: Diffusion timesteps, shape [B]. context: A list of text embeddings each with shape [L, C]. seq_len: Maximum sequence length for positional encoding. y: Conditional video inputs for image-to-video mode, same shape as x.
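The summary specifies `t` only by its shape. The Calls list below includes `sinusoidal_embedding_1d`, which suggests the timesteps are turned into sinusoidal features before they condition the transformer blocks. The following is a minimal sketch of such an embedding, assuming the cos/sin layout and base 10000 common in Wan-style code; the repository's own helper is not shown in this excerpt and may differ in details.

```python
import torch


def sinusoidal_embedding_1d(dim: int, position: torch.Tensor) -> torch.Tensor:
    """Map each scalar in `position` (shape [B]) to a `dim`-wide embedding.

    Assumption: cos half followed by sin half, base 10000. This is a sketch,
    not the repository's verified implementation.
    """
    assert dim % 2 == 0, "dim must be even to split into cos/sin halves"
    half = dim // 2
    position = position.to(torch.float64)
    # Geometric frequency ladder: 10000^(-i/half) for i in [0, half).
    freqs = torch.pow(10000.0, -torch.arange(half, dtype=torch.float64) / half)
    angles = torch.outer(position, freqs)  # [B, half]
    return torch.cat([torch.cos(angles), torch.sin(angles)], dim=1)  # [B, dim]


t = torch.tensor([0.0, 500.0, 999.0])  # one timestep per batch item: shape [B]
emb = sinusoidal_embedding_1d(256, t)
print(emb.shape)  # torch.Size([3, 256])
```

At `t = 0` the cos half is all ones and the sin half all zeros, which is a quick sanity check on the layout.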

(
    self,
    x,
    t,
    context,
    seq_len,
    y=None,
    dit_cond_dict=None,
    kv_cache=None,
    crossattn_cache=None,
    current_start=0,
    max_attention_size=1_000_000,
    frame_seqlen=None,
    cross_attn_first_call=None,
)
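The `kv_cache` and `max_attention_size` parameters describe a bounded self-attention cache. The sketch below is a hypothetical illustration of that idea, not the repository's cache logic: `update_kv_cache` is an invented helper that appends each new chunk, evicts the oldest tokens once the buffer is full, and returns only the most recent `max_attention_size` keys and values. It reuses the dict keys named in the docstring of the source below, but it assumes every call appends a brand-new chunk (re-running a chunk at the same `current_start` is not handled) and it ignores sequence parallelism.

```python
import torch


def update_kv_cache(cache, k_new, v_new, max_attention_size):
    """Hypothetical helper: append a chunk and return the attention window.

    `cache` mimics the documented layout: "k"/"v" of shape
    [B, kv_size, heads, head_dim] plus scalar "global_end_index" and
    "local_end_index" tensors. Returned tensors are views into the buffer.
    """
    n = k_new.shape[1]
    kv_size = cache["k"].shape[1]
    local_end = int(cache["local_end_index"])
    assert n <= kv_size, "chunk is larger than the cache buffer"
    overflow = local_end + n - kv_size
    if overflow > 0:
        # Evict the oldest `overflow` tokens by shifting the buffer left.
        keep = local_end - overflow
        cache["k"][:, :keep] = cache["k"][:, overflow:local_end].clone()
        cache["v"][:, :keep] = cache["v"][:, overflow:local_end].clone()
        local_end = keep
    cache["k"][:, local_end:local_end + n] = k_new
    cache["v"][:, local_end:local_end + n] = v_new
    local_end += n
    cache["local_end_index"].fill_(local_end)  # valid tokens in the buffer
    cache["global_end_index"].add_(n)  # total tokens ever appended
    # Each query may attend to at most `max_attention_size` recent tokens.
    start = max(0, local_end - max_attention_size)
    return cache["k"][:, start:local_end], cache["v"][:, start:local_end]


B, H, D, kv_size = 1, 2, 4, 6
cache = {
    "k": torch.zeros(B, kv_size, H, D),
    "v": torch.zeros(B, kv_size, H, D),
    "global_end_index": torch.tensor(0),
    "local_end_index": torch.tensor(0),
}
for step in range(3):
    chunk = torch.full((B, 3, H, D), float(step))  # 3 tokens per chunk
    k_win, v_win = update_kv_cache(cache, chunk, chunk, max_attention_size=4)
    print(step, k_win.shape[1], int(cache["global_end_index"]))
# prints: "0 3 3", then "1 4 6", then "2 4 9"
```

Separating the buffer size (`kv_size`) from the attention window (`max_attention_size`) lets the buffer hold a little more history than any single query reads, which is one plausible reason for keeping both a global and a local end index.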


import torch  # imported at module level in the source file

def sp_dit_forward_causal(
    self,
    x,
    t,
    context,
    seq_len,
    y=None,
    dit_cond_dict=None,
    kv_cache=None,
    crossattn_cache=None,
    current_start=0,
    max_attention_size=1_000_000,
    frame_seqlen=None,
    cross_attn_first_call=None,
):
    """
    x: A list of videos each with shape [C, T, H, W].
    t: Diffusion timesteps, shape [B].
    context: A list of text embeddings each with shape [L, C].
    seq_len: Maximum sequence length for positional encoding.
    y: Conditional video inputs for image-to-video mode, same shape as x.
    dit_cond_dict: Dictionary of conditioning signals. May contain the key ``c2ws_plucker_emb``
        with camera Plucker embeddings of shape [B, C, F, H, W] for camera control.
    kv_cache: Per-layer self-attention KV cache, one dict per layer. Each dict contains keys
        ``k``, ``v`` (Tensor of shape [B, kv_size, local_heads, head_dim]), ``global_end_index``,
        and ``local_end_index`` (scalar Tensors tracking the cache position).
    crossattn_cache: Per-layer cross-attention KV cache. Each dict contains keys ``k``, ``v``
        (Tensor of shape [B, text_len, num_heads, head_dim]) and ``is_init`` (bool).
    current_start: Token offset of the current chunk in the full sequence. Used to index
        into the KV cache and to compute positional embeddings correctly.
    max_attention_size: Maximum number of KV tokens each query can attend to. Limits the
        effective context window of self-attention to control memory usage.

    Note: This function uses sequence parallelism (SP) with Ulysses attention.
    The sequence is chunked across GPUs. Inside attention, an all-to-all redistributes
    the activations from the sequence-parallel to the head-parallel layout; causal RoPE
    is then applied on the full sequence, the KV cache is updated, attention runs, and
    a second all-to-all restores the sequence-parallel layout.
    """

    assert len(x) == 1  # only a single video per call is supported

    if self.model_type == 'i2v':
        assert y is not None
    # params
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)

    # i2v: concatenate the conditioning video along the channel dimension
    if y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

    # embeddings: patchify each video into a [1, dim, T', H', W'] feature grid
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    # grid_sizes: per-video (T', H', W') of the patch grid
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    # flatten the grid into a token sequence: [1, T' * H' * W', dim]
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)[0]
    assert seq_lens.max() <= seq_len
    # ... (source lines 307-411 are not included in this excerpt)
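The embedding step in the listing above turns each video into a token sequence. Running the same lines standalone shows the resulting shapes. The sizes here are hypothetical (16 latent channels per input, a `(1, 2, 2)` patch size, model width 64), and a plain `nn.Conv3d` stands in for `self.patch_embedding`; the real model's configuration may differ.

```python
import torch
import torch.nn as nn

# Hypothetical sizes; 32 input channels = 16 (x) + 16 (y) after the i2v concat.
in_dim, dim, patch = 32, 64, (1, 2, 2)
patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch, stride=patch)

x = [torch.randn(16, 5, 8, 12)]  # one video latent: [C, T, H, W]
y = [torch.randn(16, 5, 8, 12)]  # i2v conditioning, same shape as x
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]  # [32, 5, 8, 12]

x = [patch_embedding(u.unsqueeze(0)) for u in x]  # [1, 64, 5, 4, 6]
grid_sizes = torch.stack(
    [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])  # [[5, 4, 6]]
x = [u.flatten(2).transpose(1, 2) for u in x]  # [1, 5 * 4 * 6, 64]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)

print(grid_sizes.tolist(), tuple(x[0].shape), seq_lens.tolist())
# [[5, 4, 6]] (1, 120, 64) [120]
```

With these sizes each latent frame contributes 4 x 6 = 24 tokens. A per-frame token count like this is plausibly what `frame_seqlen` refers to, although the excerpt does not document that parameter.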

Callers

Nothing calls this directly.

Calls (7)

get_world_size (function, 0.85)
get_rank (function, 0.85)
gather_forward (function, 0.85)
to (method, 0.80)
size (method, 0.80)
sinusoidal_embedding_1d (function, 0.50)
unpatchify (method, 0.45)
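The docstring's note on Ulysses attention, together with the SP helpers in this list (`get_world_size`, `get_rank`, `gather_forward`), describes a layout swap around attention. The sketch below simulates that swap on a single process with plain tensor slicing: each list element stands for one rank, and the hypothetical `seq_to_head` / `head_to_seq` functions do what a real implementation would do with `torch.distributed.all_to_all` across GPUs. It assumes the sequence length and head count divide evenly by the world size.

```python
import torch


def seq_to_head(shards):
    """Simulate the first Ulysses all-to-all on one process.

    shards[r] is rank r's sequence slice, shape [B, S/P, H, D]. Afterwards,
    rank j holds the full sequence for head group j: [B, S, H/P, D].
    """
    world = len(shards)
    # Each rank splits its heads into `world` groups; group j goes to rank j.
    pieces = [s.chunk(world, dim=2) for s in shards]
    # Rank j concatenates what it received from every rank along the sequence.
    return [torch.cat([pieces[r][j] for r in range(world)], dim=1)
            for j in range(world)]


def head_to_seq(shards):
    """Inverse all-to-all: [B, S, H/P, D] per rank back to [B, S/P, H, D]."""
    world = len(shards)
    # Each rank splits its full sequence into `world` slices; slice j goes to rank j.
    pieces = [s.chunk(world, dim=1) for s in shards]
    # Rank j concatenates the head groups it received along the head axis.
    return [torch.cat([pieces[r][j] for r in range(world)], dim=2)
            for j in range(world)]


B, S, H, D, world = 1, 8, 4, 3, 2
full = torch.randn(B, S, H, D)
seq_shards = list(full.chunk(world, dim=1))  # what each GPU holds under SP
head_shards = seq_to_head(seq_shards)
print(head_shards[0].shape)  # torch.Size([1, 8, 2, 3])
```

After the first exchange each rank holds the full sequence for its subset of heads, so causal RoPE and attention can run per head without further communication; the inverse exchange then restores the sequence shards.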

Tested by

No test coverage detected.