Repository: github.com/huggingface/diffusers

Class ContextParallelConfig

src/diffusers/models/_modeling_parallel.py, lines 42–154



```python
# Imports added so this excerpt runs on its own
from dataclasses import dataclass
from typing import Literal

import torch
import torch.distributed.device_mesh


@dataclass
class ContextParallelConfig:
    """
    Configuration for context parallelism.

    Args:
        ring_degree (`int`, *optional*, defaults to `1`):
            Number of devices to use for Ring Attention within a context parallel region. The sequence is split
            across devices, and each device computes attention between its local Q chunk and the KV chunks passed
            sequentially around the ring. Ring Attention uses less memory (each device holds only 1/N of KV at a
            time) and overlaps compute with communication, but requires N iterations to see all tokens. Best for
            long sequences with limited memory or bandwidth. Must be a divisor of the total number of devices in
            the context parallel mesh.
        ulysses_degree (`int`, *optional*, defaults to `1`):
            Number of devices to use for Ulysses Attention. The sequence is split across devices. Each device
            computes its local Q, K, and V, then an all-to-all exchange gives each device the full sequence for a
            subset of the attention heads, so attention is computed in one pass. Ulysses Attention uses more
            memory (K and V for the full sequence are held for the local heads) and requires high-bandwidth
            all-to-all communication, but has lower latency. Best for moderate sequence lengths with good
            interconnect bandwidth.
        convert_to_fp32 (`bool`, *optional*, defaults to `True`):
            Whether to convert the output and LSE (log-sum-exp) to float32 for numerical stability in ring attention.
        rotate_method (`str`, *optional*, defaults to `"allgather"`):
            Method used to rotate key/value states across devices in ring attention. Currently, only `"allgather"`
            is supported.
        ulysses_anything (`bool`, *optional*, defaults to `False`):
            Whether to enable "Ulysses Anything" mode, which supports sequence lengths and head counts that are not
            evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
            `ring_degree` must be 1.
        ring_anything (`bool`, *optional*, defaults to `False`):
            Whether to enable "Ring Anything" mode, which supports arbitrary sequence lengths. When enabled,
            `ring_degree` must be greater than 1 and `ulysses_degree` must be 1.
        mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
            A custom device mesh to use for context parallelism. If provided, this mesh is used instead of creating
            a new one. This is useful when combining context parallelism with other parallelism strategies (e.g.,
            FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and "ulysses"
            dimensions; use size 1 for a dimension that is not used (e.g., `mesh_shape=(2, 1, 4)` with
            `mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only, combined with FSDP).
    """

    ring_degree: int | None = None
    ulysses_degree: int | None = None
    convert_to_fp32: bool = True
    # TODO: support alltoall
    rotate_method: Literal["allgather", "alltoall"] = "allgather"
    mesh: torch.distributed.device_mesh.DeviceMesh | None = None
    # Whether to enable ulysses anything attention to support
    # any sequence lengths and any head numbers.
    ulysses_anything: bool = False
    # Whether to enable ring anything attention to support any sequence lengths.
    ring_anything: bool = False

    _rank: int | None = None
    _world_size: int | None = None
    _device: torch.device | None = None
    _mesh: torch.distributed.device_mesh.DeviceMesh | None = None
    _flattened_mesh: torch.distributed.device_mesh.DeviceMesh | None = None
    _ring_mesh: torch.distributed.device_mesh.DeviceMesh | None = None
    _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh | None = None
    _ring_local_rank: int | None = None
    _ulysses_local_rank: int | None = None
```
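The `ring_degree` description says that passing KV chunks around the ring still covers all tokens after N steps. The sketch below is a minimal single-process illustration of why the result is exact. Each step produces a partial attention output plus its log-sum-exp (LSE), and the partials are merged by LSE rescaling.

Some assumptions apply:
- The helper names are hypothetical, and this is plain Python, not diffusers' kernels.
- It handles a single query vector.
- The chunk order assumes each device forwards its KV to the next rank.
- Python floats are double precision, so the float32 concern behind `convert_to_fp32` is not modeled.

```python
import math
import random


def attention_with_lse(q, keys, values):
    """Softmax attention for one query vector over lists of key and value vectors.

    Returns the attention output and the log-sum-exp (LSE) of the scaled scores.
    """
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max before exponentiating, for stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(len(values[0]))]
    return out, m + math.log(total)


def merge_partials(out_a, lse_a, out_b, lse_b):
    """Combine attention results computed over two disjoint sets of keys."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    w_a, w_b = math.exp(lse_a - lse), math.exp(lse_b - lse)
    return [w_a * a + w_b * b for a, b in zip(out_a, out_b)], lse


def simulated_ring_attention(q, keys, values, ring_degree, rank):
    """Hypothetical single-process simulation of one ring participant.

    KV is split into `ring_degree` equal chunks. At step s the device at `rank`
    holds chunk (rank - s) % ring_degree, so it only sees 1/N of KV at a time.
    """
    if len(keys) % ring_degree:
        raise ValueError("this sketch assumes the KV length is divisible by ring_degree")
    size = len(keys) // ring_degree
    out, lse = None, None
    for step in range(ring_degree):
        c = (rank - step) % ring_degree
        part_out, part_lse = attention_with_lse(
            q, keys[c * size:(c + 1) * size], values[c * size:(c + 1) * size]
        )
        if out is None:
            out, lse = part_out, part_lse
        else:
            out, lse = merge_partials(out, lse, part_out, part_lse)
    return out, lse


random.seed(0)
keys = [[random.gauss(0, 1) for _ in range(4)] for _ in range(8)]
values = [[random.gauss(0, 1) for _ in range(3)] for _ in range(8)]
q = [random.gauss(0, 1) for _ in range(4)]

full_out, full_lse = attention_with_lse(q, keys, values)
ring_out, ring_lse = simulated_ring_attention(q, keys, values, ring_degree=4, rank=1)
# The difference is floating-point round-off only
print(max(abs(a - b) for a, b in zip(full_out, ring_out)))
```

Because the merge is associative, the order in which chunks arrive does not change the result. That is what lets each device start with its own chunk.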

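The `ulysses_degree` description rests on an all-to-all re-sharding. The sketch below simulates that exchange in one process, using (token, head) labels in place of tensors. The functions are hypothetical illustrations, not diffusers' implementation.

- **Before the exchange:** each rank holds all heads for a contiguous block of tokens.
- **After the exchange:** each rank holds every token for `num_heads / world_size` heads, which is why attention can be computed in one pass.

After attention, a second all-to-all moves the output back to sequence sharding; that step is not shown. This even-split sketch rejects sequence lengths and head counts that are not divisible by the world size. The docstring says the `ulysses_anything` mode is meant to handle that case.

```python
def shard_by_sequence(seq_len, num_heads, world_size):
    """Layout before the exchange: rank r holds every head for one contiguous token block.

    Each element is a (token_index, head_index) label standing in for a Q/K/V vector.
    """
    if seq_len % world_size:
        raise ValueError("this even-split sketch needs seq_len divisible by world_size")
    block = seq_len // world_size
    return [
        [[(t, h) for h in range(num_heads)] for t in range(r * block, (r + 1) * block)]
        for r in range(world_size)
    ]


def all_to_all_sequence_to_heads(shards, num_heads):
    """Simulate the all-to-all that turns sequence sharding into head sharding."""
    world_size = len(shards)
    if num_heads % world_size:
        raise ValueError("this even-split sketch needs num_heads divisible by world_size")
    per_rank = num_heads // world_size
    # send[src][dst]: src's token rows, restricted to the heads that dst will own
    send = [
        [[row[dst * per_rank:(dst + 1) * per_rank] for row in shards[src]] for dst in range(world_size)]
        for src in range(world_size)
    ]
    # Each rank concatenates what it received in source-rank order, restoring token order
    return [[row for src in range(world_size) for row in send[src][dst]] for dst in range(world_size)]


before = shard_by_sequence(seq_len=8, num_heads=4, world_size=2)
after = all_to_all_sequence_to_heads(before, num_heads=4)
print(len(before[0]), len(before[0][0]))  # 4 4  (4 tokens x 4 heads per rank)
print(len(after[0]), len(after[0][0]))    # 8 2  (8 tokens x 2 heads per rank)
```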
Callers (3)

_context_parallel_worker · Function · 0.90
_custom_mesh_worker · Function · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected

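Two hypothetical helpers make the flag constraints in the `ContextParallelConfig` docstring, and the custom `mesh` layout, concrete:

- `check_context_parallel_options` checks the stated flag constraints. Treating `None` as 1 follows the docstring's stated defaults.
- `rank_groups` lists which ranks share a mesh dimension. It assumes ranks are laid out row-major over `mesh_shape` (rank = flat index into the mesh), as PyTorch's `init_device_mesh` does. It uses the docstring's `(2, 1, 4)` example.

Neither helper is part of diffusers, and the library's own validation is not shown in the excerpt above.

```python
import itertools
import math


def check_context_parallel_options(ring_degree=None, ulysses_degree=None,
                                   ulysses_anything=False, ring_anything=False,
                                   mesh_dim_names=None):
    """Hypothetical check of the constraints stated in the ContextParallelConfig docstring."""
    ring = 1 if ring_degree is None else ring_degree           # docstring: defaults to 1
    ulysses = 1 if ulysses_degree is None else ulysses_degree  # docstring: defaults to 1
    if ulysses_anything and not (ulysses > 1 and ring == 1):
        raise ValueError("ulysses_anything requires ulysses_degree > 1 and ring_degree == 1")
    if ring_anything and not (ring > 1 and ulysses == 1):
        raise ValueError("ring_anything requires ring_degree > 1 and ulysses_degree == 1")
    if mesh_dim_names is not None and not {"ring", "ulysses"} <= set(mesh_dim_names):
        raise ValueError('a custom mesh must have both "ring" and "ulysses" dimensions')
    return ring, ulysses


def rank_groups(mesh_shape, mesh_dim_names, dim_name):
    """Hypothetical helper: groups of global ranks that differ only along `dim_name`,
    assuming ranks are laid out row-major over mesh_shape."""
    dim = mesh_dim_names.index(dim_name)
    strides = [math.prod(mesh_shape[i + 1:]) for i in range(len(mesh_shape))]
    others = [i for i in range(len(mesh_shape)) if i != dim]
    groups = []
    for idx in itertools.product(*(range(mesh_shape[i]) for i in others)):
        base = sum(j * strides[i] for j, i in zip(idx, others))
        groups.append([base + k * strides[dim] for k in range(mesh_shape[dim])])
    return groups


shape, names = (2, 1, 4), ("ring", "ulysses", "fsdp")  # the docstring's example mesh
print(check_context_parallel_options(ring_degree=2, mesh_dim_names=names))  # (2, 1)
print(rank_groups(shape, names, "ring"))  # [[0, 4], [1, 5], [2, 6], [3, 7]]
print(rank_groups(shape, names, "fsdp"))  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

With a size-1 "ulysses" dimension, every Ulysses group holds a single rank, so only ring attention does cross-device work here. FSDP shards across the other axis.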