Create expert and data parallel groups. When mp_size is None or 1: legacy consecutive ordering (backward compatible). When mp_size > 1 and mp_mode=="tp": TP-strided rank ordering. When mp_size > 1 and mp_mode=="sp": consecutive rank ordering. Note: the caller of this function is responsible for checking whether the groups already exist.
(expert_parallel_size_,
mp_size=None,
pp_size=None,
mp_mode="tp",
use_data_before_expert_parallel_=False)
| 238 | |
| 239 | |
| 240 | def _create_expert_and_data_parallel(expert_parallel_size_, |
| 241 | mp_size=None, |
| 242 | pp_size=None, |
| 243 | mp_mode="tp", |
| 244 | use_data_before_expert_parallel_=False): |
| 245 | """Create expert and data parallel groups. |
| 246 | |
| 247 | When mp_size is None or 1: legacy consecutive ordering (backward compatible). |
| 248 | When mp_size > 1 and mp_mode=="tp": TP-strided rank ordering. |
| 249 | When mp_size > 1 and mp_mode=="sp": consecutive rank ordering. |
| 250 | |
| 251 | Note: Caller of this function is responsible to check if the groups already exist. |
| 252 | |
| 253 | Example - E + D parallel (legacy path) |
| 254 | world_size = 16 |
| 255 | expert_parallel_size = 2 # number of experts in same group |
| 256 | expert_data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15] - all reduce is only on MoE params |
| 257 | expert_parallel_group = [0, 1], [2,3], [4,5], [6,7], [8,9] - no all reduce, but all to all |
| 258 | data_parallel_group = [0,1,...,15] - all reduce is only on non-MoE |
| 259 | |
| 260 | Args: |
| 261 | expert_parallel_size_ (int): Expert parallel group size. |
| 262 | mp_size (int, optional): Model parallel size (TP or SP). None treated as 1. |
| 263 | pp_size (int, optional): Pipeline parallel size. None falls back to mpu. |
| 264 | mp_mode (str): "tp" for TP-strided ordering, "sp" for consecutive ordering. |
| 265 | use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology. |
| 266 | """ |
| 267 | assert dist.is_initialized() |
| 268 | |
| 269 | # Resolve parameters for backward compat |
| 270 | effective_mp_size = 1 if mp_size is None else mp_size |
| 271 | |
| 272 | log_dist(f'Creating expert and data parallel groups with size {expert_parallel_size_}', ranks=[0]) |
| 273 | world_size = dist.get_world_size() |
| 274 | |
| 275 | # Resolve pp_size |
| 276 | if pp_size is not None: |
| 277 | pp_world_size = pp_size |
| 278 | else: |
| 279 | pp_world_size = 1 if mpu is None else bwc_pipeline_parallel_world_size(mpu) |
| 280 | |
| 281 | rank = dist.get_rank() |
| 282 | |
| 283 | pp_stride = world_size // pp_world_size |
| 284 | _ensure_divisibility(pp_stride, expert_parallel_size_) |
| 285 | |
| 286 | group_name = f"ep_size_{expert_parallel_size_}" |
| 287 | |
| 288 | global _EXPERT_DATA_PARALLEL_GROUP |
| 289 | global _EXPERT_DATA_PARALLEL_GROUP_RANKS |
| 290 | global _EXPERT_PARALLEL_GROUP |
| 291 | global _EXPERT_PARALLEL_GROUP_RANKS |
| 292 | |
| 293 | # Legacy path: mp_size <= 1 (preserves exact original behavior) |
| 294 | if effective_mp_size <= 1: |
| 295 | ep_stride = pp_stride // expert_parallel_size_ |
| 296 | |
| 297 | # Build the expert data parallel groups. |
nothing calls this directly
no test coverage detected
searching dependent graphs…