Generate expert parallel and expert data parallel group ranks list. Example - E + M + D parallel: world_size = 16; model_degree = 2; expert_degree = 4 (number of experts in same group); mp_group = [0,1], [2,3], [4,5], ...; data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15]
(world_size,
tensor_parallel_size_,
expert_parallel_size_,
pipeline_parallel_size_=1,
use_data_before_expert_parallel_=False)
| 383 | |
| 384 | |
| 385 | def _get_expert_parallel_ranks(world_size, |
| 386 | tensor_parallel_size_, |
| 387 | expert_parallel_size_, |
| 388 | pipeline_parallel_size_=1, |
| 389 | use_data_before_expert_parallel_=False): |
| 390 | """Generate expert parallel and expert data parallel group ranks list. |
| 391 | |
| 392 | Example - E + M + D parallel |
| 393 | world_size = 16 |
| 394 | model_degree = 2 |
| 395 | expert_degree = 4 # number of experts in same group |
| 396 | mp_group = [0, 1], [2,3], [4,5] ... |
| 397 | data_parallel_group =[0,2,4,6,8,10, 12,14], [1,3,5,7,9,11,13,15] |
| 398 | expert_parallel_group = [0,2,4,6], [8,10,12,14] [1,3,5,7], [9,11,13,15] |
| 399 | expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15] |
| 400 | |
| 401 | Args: |
| 402 | world_size (int): Distributed world size. |
| 403 | tensor_parallel_size_ (int): Tensor parallel group size. |
| 404 | expert_parallel_size_ (int): Expert parallel group size. |
| 405 | pipeline_parallel_size_ (int): Pipeline parallel group size |
| 406 | use_data_before_expert_parallel_ (bool): Use the D + E instead of E + D topology |
| 407 | Returns: |
| 408 | Expert parallel group ranks and Expert data parallel group ranks list. |
| 409 | """ |
| 410 | _ensure_divisibility(world_size, tensor_parallel_size_ * pipeline_parallel_size_) |
| 411 | dp_world_size = world_size // (tensor_parallel_size_ * pipeline_parallel_size_) |
| 412 | _ensure_divisibility(dp_world_size, expert_parallel_size_) |
| 413 | |
| 414 | # Generate data parallel groups |
| 415 | data_parallel_groups = [] |
| 416 | dp_group_size = tensor_parallel_size_ |
| 417 | pp_stride = world_size // pipeline_parallel_size_ |
| 418 | |
| 419 | if use_data_before_expert_parallel_: |
| 420 | dp_stride = world_size // expert_parallel_size_ // tensor_parallel_size_ // pipeline_parallel_size_ |
| 421 | for pp_stage_start in range(0, world_size, pp_stride): |
| 422 | pp_stage_next = pp_stage_start + pp_stride |
| 423 | for i in range(dp_group_size): |
| 424 | data_parallel_groups.append(list()) |
| 425 | for ds in range(dp_stride): |
| 426 | # [0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30] |
| 427 | # [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31] |
| 428 | data_parallel_groups[-1].extend( |
| 429 | list( |
| 430 | range(pp_stage_start + i + ds * tensor_parallel_size_, pp_stage_next, |
| 431 | dp_stride * tensor_parallel_size_))) |
| 432 | else: |
| 433 | for pp_stage_start in range(0, world_size, pp_stride): |
| 434 | pp_stage_next = pp_stage_start + pp_stride |
| 435 | for i in range(dp_group_size): |
| 436 | data_parallel_groups.append(list(range(pp_stage_start + i, pp_stage_next, dp_group_size))) |
| 437 | |
| 438 | expert_parallel_groups = [] |
| 439 | expert_data_parallel_groups = [] |
| 440 | for dp_ranks in data_parallel_groups: |
| 441 | # partition of expert parallel groups, e.g. [0,2,4,6], [8,10,12,14] |
| 442 | part_ep_groups = [] |
searching dependent graphs…