Create the shard groups and replicate groups from config_file. TODO: consider broadcasting the config from rank 0. Returns: MiCS_CommGroups.
(
shard_size,
dp_group,
hierarchical_allgather=False,
mpu=None,
)
| 46 | |
| 47 | |
| 48 | def create_mics_comm_groups( |
| 49 | shard_size, |
| 50 | dp_group, |
| 51 | hierarchical_allgather=False, |
| 52 | mpu=None, |
| 53 | ): |
| 54 | """ |
| 55 | create shard-group, replicate-group from config_file |
| 56 | TODO: consider broadcast the config from rank0 |
| 57 | |
| 58 | Returns: |
| 59 | MiCS_CommGroups |
| 60 | """ |
| 61 | # env var for debugging purpose |
| 62 | ndevices_per_node = int(os.environ.get("NDEV_PER_NODE", get_accelerator().device_count())) |
| 63 | _log_rank0(f'creating MiCS communication groups with per node device size {ndevices_per_node}') |
| 64 | groups = MiCS_CommGroups() |
| 65 | |
| 66 | if mpu is not None: |
| 67 | assert dp_group == mpu.get_data_parallel_group() |
| 68 | |
| 69 | # full size of the world |
| 70 | world_size = dist.get_world_size() |
| 71 | # global rank |
| 72 | global_rank = dist.get_rank() |
| 73 | |
| 74 | config = _generate_mics_config(world_size, ndevices_per_node, shard_size, 1) |
| 75 | ranks_of_shard_group = config['shard_groups'] |
| 76 | ranks_of_repli_group = config['replicate_groups'] |
| 77 | if len(ranks_of_repli_group) == 0: |
| 78 | assert len(ranks_of_shard_group) == 1, "replicate groups are empty only for single shard group" |
| 79 | for r in ranks_of_shard_group[0]: |
| 80 | ranks_of_repli_group.append([r]) |
| 81 | |
| 82 | # for simplicity |
| 83 | assert _sizes_all_same(ranks_of_repli_group), "replicate groups must have the same size" |
| 84 | assert _sizes_all_same(ranks_of_shard_group), "shard groups must have the same size" |
| 85 | |
| 86 | assert sum([len(g) for g in ranks_of_shard_group]) == dist.get_world_size(), "all sharded ranks " |
| 87 | if len(ranks_of_shard_group) > 1: # if only shard on one group then no need for replicate groups |
| 88 | assert len(ranks_of_shard_group) == len( |
| 89 | ranks_of_repli_group[0]), "number of shard groups must equal to the size of each replicate group" |
| 90 | |
| 91 | global_rank = dist.get_rank() |
| 92 | # create shard groups |
| 93 | for shard_ranks in ranks_of_shard_group: |
| 94 | _group = dist.new_group(shard_ranks) |
| 95 | if global_rank in shard_ranks: |
| 96 | groups.param_shard_group = _group |
| 97 | groups.param_shard_size = len(shard_ranks) |
| 98 | groups.param_shard_rank = dist.get_rank(_group) |
| 99 | logger.info(f'rank {global_rank}, shard group' |
| 100 | f' {groups.param_shard_rank}/{dist.get_world_size(group=_group)}') |
| 101 | |
| 102 | # create replicate groups |
| 103 | for repli_ranks in ranks_of_repli_group: |
| 104 | if len(repli_ranks) > 1: |
| 105 | _group = dist.new_group(repli_ranks) |
no test coverage detected
searching dependent graphs…