Create expert and data parallel groups based on MPU (model parallel) group. Note: The caller of this function is responsible for checking whether the groups already exist. Example - E + M + D parallel: world_size = 16, model_degree = 2, expert_degree = 4 (number of experts in the same group).
(expert_parallel_size_, mpu, use_data_before_expert_parallel_=False)
| 452 | |
| 453 | |
def _create_expert_data_and_model_parallel(expert_parallel_size_, mpu, use_data_before_expert_parallel_=False):
    """
    Build the expert-parallel and expert-data-parallel process groups on top of
    the model parallel layout reported by ``mpu``.

    Note: The caller is responsible for checking whether the groups already exist.

    Args:
        expert_parallel_size_: Number of ranks in one expert parallel group. The
            data parallel world size must be divisible by it.
        mpu: Model parallel unit object queried for the tensor (and, when not
            ``None``, pipeline) parallel world sizes.
        use_data_before_expert_parallel_: Forwarded unchanged to
            ``_get_expert_parallel_ranks``.

    Side effects:
        Sets the module-level ``expert_tensor_parallel_world_size`` and, when no
        entry for ``ep_size_<expert_parallel_size_>`` exists yet, records the
        calling rank's group and rank list in the ``_EXPERT_PARALLEL_GROUP*`` and
        ``_EXPERT_DATA_PARALLEL_GROUP*`` registries.

    Example - E + M + D parallel
        world_size = 16
        model_degree = 2
        expert_degree = 4 # number of experts in same group
        mp_group = [0, 1], [2,3], [4,5] ...
        data_parallel_group = [0,2,4,6,8,10,12,14], [1,3,5,7,9,11,13,15]
        expert_parallel_group = [0,2,4,6], [8,10,12,14], [1,3,5,7], [9,11,13,15]
        expert_data_parallel_group = [0,8],[2,10],[4,12],[6,14], [1,9],[3,11],[5,13],[7,15]
    """
    global expert_tensor_parallel_world_size
    global _EXPERT_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP
    global _EXPERT_PARALLEL_GROUP_RANKS, _EXPERT_DATA_PARALLEL_GROUP_RANKS

    assert dist.is_initialized(), "dist is not initialized"

    tp_size = bwc_tensor_model_parallel_world_size(mpu)
    expert_tensor_parallel_world_size = tp_size

    total_ranks = dist.get_world_size()
    this_rank = dist.get_rank()
    dp_size = _get_data_parallel_world_size()
    if mpu is None:
        pp_size = 1
    else:
        pp_size = bwc_pipeline_parallel_world_size(mpu)

    _ensure_divisibility(total_ranks, tp_size)
    _ensure_divisibility(dp_size, expert_parallel_size_)

    summary = (f"Creating deepspeed groups with model parallel size {tp_size}, "
               f"pipeline parallel size {pp_size}, expert parallel size {expert_parallel_size_}, "
               f"world size {total_ranks}, dp world size {dp_size}")
    log_dist(summary, [0])

    registry_key = f"ep_size_{expert_parallel_size_}"

    # Decide once, before any group is created: nothing is built when either
    # registry already holds an entry for this expert parallel size.
    if registry_key in _EXPERT_DATA_PARALLEL_GROUP or registry_key in _EXPERT_PARALLEL_GROUP:
        return

    ep_rank_sets, edp_rank_sets = _get_expert_parallel_ranks(total_ranks, tp_size, expert_parallel_size_, pp_size,
                                                             use_data_before_expert_parallel_)

    # Expert parallel groups are created first, then expert data parallel groups.
    # dist.new_group is invoked for every rank set, not only the ones containing
    # this rank; only the matching set is recorded in the registries.
    creation_plan = (
        (ep_rank_sets, _EXPERT_PARALLEL_GROUP, _EXPERT_PARALLEL_GROUP_RANKS),
        (edp_rank_sets, _EXPERT_DATA_PARALLEL_GROUP, _EXPERT_DATA_PARALLEL_GROUP_RANKS),
    )
    for rank_sets, group_registry, ranks_registry in creation_plan:
        for member_ranks in rank_sets:
            handle = dist.new_group(member_ranks)
            if this_rank in list(member_ranks):
                group_registry[registry_key] = handle
                ranks_registry[registry_key] = member_ranks
| 509 | |
| 510 | |
| 511 | def _get_max_expert_size(): |
nothing calls this directly
no test coverage detected
searching dependent graphs…