x: A list of videos each with shape [C, T, H, W]. t: [B]. context: A list of text embeddings each with shape [L, C]. seq_len: Maximum sequence length for positional encoding. y: Conditional video inputs for image-to-video mode, same shape as x.
(
self,
x,
t,
context,
seq_len,
y=None,
dit_cond_dict=None,
kv_cache=None,
crossattn_cache=None,
current_start=0,
max_attention_size=1_000_000,
frame_seqlen=None,
cross_attn_first_call=None,
)
| 247 | |
| 248 | |
| 249 | def sp_dit_forward_causal( |
| 250 | self, |
| 251 | x, |
| 252 | t, |
| 253 | context, |
| 254 | seq_len, |
| 255 | y=None, |
| 256 | dit_cond_dict=None, |
| 257 | kv_cache=None, |
| 258 | crossattn_cache=None, |
| 259 | current_start=0, |
| 260 | max_attention_size=1_000_000, |
| 261 | frame_seqlen=None, |
| 262 | cross_attn_first_call=None, |
| 263 | ): |
| 264 | """ |
| 265 | x: A list of videos each with shape [C, T, H, W]. |
| 266 | t: [B]. |
| 267 | context: A list of text embeddings each with shape [L, C]. |
| 268 | seq_len: Maximum sequence length for positional encoding. |
| 269 | y: Conditional video inputs for image-to-video mode, same shape as x. |
| 270 | dit_cond_dict: Dictionary of conditioning signals. May contain key ``c2ws_plucker_emb`` |
| 271 | with camera Plucker embeddings of shape [B, C, F, H, W] for camera control. |
| 272 | kv_cache: Per-layer self-attention KV cache. Each dict contains keys ``k``, ``v`` |
| 273 | (Tensor of shape [B, kv_size, local_heads, head_dim]), ``global_end_index``, |
| 274 | and ``local_end_index`` (scalar Tensors tracking cache position). |
| 275 | crossattn_cache: Per-layer cross-attention KV cache. Each dict contains keys ``k``, ``v`` |
| 276 | (Tensor of shape [B, text_len, num_heads, head_dim]) and ``is_init`` (bool). |
| 277 | current_start: Token offset of the current chunk in the full sequence. Used to index |
| 278 | into the KV cache and compute positional embeddings correctly. |
| 279 | max_attention_size: Maximum number of KV tokens each query can attend to. Limits the |
| 280 | effective context window of self-attention to control memory usage. |
| 281 | |
| 282 | Note: This function uses sequence parallel (SP) with Ulysses attention. |
| 283 | The sequence is chunked across GPUs. Inside attention, all-to-all redistributes |
| 284 | from sequence-parallel to head-parallel, applies causal RoPE on the full sequence, |
| 285 | updates the KV cache, runs attention, then all-to-all back. |
| 286 | """ |
| 287 | |
| 288 | assert len(x) == 1 |
| 289 | |
| 290 | if self.model_type == 'i2v': |
| 291 | assert y is not None |
| 292 | # params |
| 293 | device = self.patch_embedding.weight.device |
| 294 | if self.freqs.device != device: |
| 295 | self.freqs = self.freqs.to(device) |
| 296 | |
| 297 | if y is not None: |
| 298 | x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] |
| 299 | |
| 300 | # embeddings |
| 301 | x = [self.patch_embedding(u.unsqueeze(0)) for u in x] |
| 302 | grid_sizes = torch.stack( |
| 303 | [torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) |
| 304 | x = [u.flatten(2).transpose(1, 2) for u in x] |
| 305 | seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)[0] |
| 306 | assert seq_lens.max() <= seq_len |
nothing calls this directly
no test coverage detected