@dataclass
class ContextParallelConfig:
    """
    Configuration for context parallelism.

    Args:
        ring_degree (`int`, *optional*, defaults to `1`):
            Number of devices to use for Ring Attention within a context parallel region. Must be a divisor of the
            total number of devices in the context parallel mesh. The sequence is split across devices, and each
            device computes attention between its local Q and the KV chunks passed sequentially around the ring.
            Lower memory (holds only 1/N of KV at a time) and overlaps compute with communication, but requires N
            iterations to see all tokens. Best for long sequences with limited memory/bandwidth.
        ulysses_degree (`int`, *optional*, defaults to `1`):
            Number of devices to use for Ulysses Attention. The sequence is split across devices. Each device
            computes local QKV, then all-gathers all KV chunks to compute full attention in one pass. Higher memory
            (stores all KV) and requires high-bandwidth all-to-all communication, but lower latency. Best for
            moderate sequences with good interconnect bandwidth.
        convert_to_fp32 (`bool`, *optional*, defaults to `True`):
            Whether to convert output and LSE to float32 for ring attention numerical stability.
        rotate_method (`str`, *optional*, defaults to `"allgather"`):
            Method to use for rotating key/value states across devices in ring attention. Currently, only
            `"allgather"` is supported.
        ulysses_anything (`bool`, *optional*, defaults to `False`):
            Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
            are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
            `ring_degree` must be 1.
        ring_anything (`bool`, *optional*, defaults to `False`):
            Whether to enable "Ring Anything" mode, which supports arbitrary sequence lengths. When enabled,
            `ring_degree` must be greater than 1 and `ulysses_degree` must be 1.
        mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
            A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
            creating a new one. This is useful when combining context parallelism with other parallelism strategies
            (e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
            "ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
            `mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
    """

    ring_degree: int | None = None
    ulysses_degree: int | None = None
    convert_to_fp32: bool = True
    # TODO: support alltoall
    rotate_method: Literal["allgather", "alltoall"] = "allgather"
    mesh: torch.distributed.device_mesh.DeviceMesh | None = None
    # Whether to enable Ulysses Anything attention, which supports
    # any sequence length and any number of heads.
    ulysses_anything: bool = False
    # Whether to enable Ring Anything attention, which supports any sequence length.
    ring_anything: bool = False

    # Internal state; unset (None) until populated.
    _rank: int | None = None
    _world_size: int | None = None
    _device: torch.device | None = None
    _mesh: torch.distributed.device_mesh.DeviceMesh | None = None
    _flattened_mesh: torch.distributed.device_mesh.DeviceMesh | None = None
    _ring_mesh: torch.distributed.device_mesh.DeviceMesh | None = None
    _ulysses_mesh: torch.distributed.device_mesh.DeviceMesh | None = None
    _ring_local_rank: int | None = None
    _ulysses_local_rank: int | None = None
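The constraints stated in the docstring can be collected into a small standalone check. The sketch below is illustrative only: `check_context_parallel_degrees` and its `mesh_size` parameter are hypothetical names, not part of the class above, and treating an unset degree as 1 is an assumption based on the documented defaults.

```python
def check_context_parallel_degrees(
    ring_degree=None,
    ulysses_degree=None,
    mesh_size=1,
    ulysses_anything=False,
    ring_anything=False,
):
    """Hypothetical helper: return (ring, ulysses) degrees or raise ValueError."""
    # Assumption: an unset degree means the strategy is unused, i.e. degree 1.
    ring = 1 if ring_degree is None else ring_degree
    ulysses = 1 if ulysses_degree is None else ulysses_degree

    if ring < 1 or ulysses < 1:
        raise ValueError("ring_degree and ulysses_degree must be positive")
    # Documented rule: ring_degree must divide the number of devices in the mesh.
    if mesh_size % ring != 0:
        raise ValueError(f"ring_degree={ring} does not divide mesh size {mesh_size}")
    # Documented rule: "Ulysses Anything" needs ulysses_degree > 1 and ring_degree == 1.
    if ulysses_anything and not (ulysses > 1 and ring == 1):
        raise ValueError("ulysses_anything requires ulysses_degree > 1 and ring_degree == 1")
    # Documented rule: "Ring Anything" needs ring_degree > 1 and ulysses_degree == 1.
    if ring_anything and not (ring > 1 and ulysses == 1):
        raise ValueError("ring_anything requires ring_degree > 1 and ulysses_degree == 1")
    return ring, ulysses
```

For example, `check_context_parallel_degrees(ring_degree=2, mesh_size=8)` passes, while `ring_degree=3` with the same mesh size is rejected because 3 does not divide 8.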
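The `convert_to_fp32` option refers to the attention output and its log-sum-exp (LSE). As background, ring-style attention can combine partial results from separate KV chunks by reweighting each partial output with its LSE. The sketch below shows that general merge rule in pure Python for a single query. It is not the library's implementation, and `attend` and `merge` are hypothetical names chosen for illustration.

```python
import math


def attend(q, keys, values):
    """Softmax attention for one query over one KV chunk; returns (output, lse)."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)  # subtract the max before exponentiating for stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    out = [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]
    return out, m + math.log(total)


def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial attention results using their log-sum-exp values."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    # Each chunk contributes in proportion to its share of the softmax mass.
    w_a, w_b = math.exp(lse_a - lse), math.exp(lse_b - lse)
    return [w_a * a + w_b * b for a, b in zip(out_a, out_b)], lse
```

Merging the results of two chunks reproduces attention over the concatenated keys and values up to floating-point error. Because the merge repeatedly exponentiates differences of LSE values, carrying the output and LSE in a higher-precision type limits accumulated rounding error.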