Configuration for applying different parallelisms. Args: context_parallel_config (`ContextParallelConfig`, *optional*): Configuration for context parallelism.
```python
from dataclasses import dataclass

import torch

# `ContextParallelConfig` is defined earlier in the same module and is not shown here.


@dataclass
class ParallelConfig:
    """
    Configuration for applying different parallelisms.

    Args:
        context_parallel_config (`ContextParallelConfig`, *optional*):
            Configuration for context parallelism.
    """

    context_parallel_config: ContextParallelConfig | None = None

    # Runtime state populated by `setup()`; each field stays `None` until then.
    _rank: int | None = None
    _world_size: int | None = None
    _device: torch.device | None = None
    _mesh: torch.distributed.device_mesh.DeviceMesh | None = None

    def setup(
        self,
        rank: int,
        world_size: int,
        device: torch.device,
        *,
        mesh: torch.distributed.device_mesh.DeviceMesh | None = None,
    ):
        # Record this process's distributed context.
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._mesh = mesh
        # Forward the same context to the nested context-parallel config, if one is set.
        if self.context_parallel_config is not None:
            self.context_parallel_config.setup(rank, world_size, device, mesh)
```