MCPcopy
hub / github.com/deepspeedai/DeepSpeed / create_mics_comm_groups

Function create_mics_comm_groups

deepspeed/runtime/zero/mics_utils.py:48–160  ·  view source on GitHub ↗

Create the shard group and replicate group from the config file. TODO: consider broadcasting the config from rank 0. Returns: MiCS_CommGroups

(
    shard_size,
    dp_group,
    hierarchical_allgather=False,
    mpu=None,
)

Source from the content-addressed store, hash-verified

46
47
48def create_mics_comm_groups(
49 shard_size,
50 dp_group,
51 hierarchical_allgather=False,
52 mpu=None,
53):
54 """
55 create shard-group, replicate-group from config_file
56 TODO: consider broadcast the config from rank0
57
58 Returns:
59 MiCS_CommGroups
60 """
61 # env var for debugging purpose
62 ndevices_per_node = int(os.environ.get("NDEV_PER_NODE", get_accelerator().device_count()))
63 _log_rank0(f'creating MiCS communication groups with per node device size {ndevices_per_node}')
64 groups = MiCS_CommGroups()
65
66 if mpu is not None:
67 assert dp_group == mpu.get_data_parallel_group()
68
69 # full size of the world
70 world_size = dist.get_world_size()
71 # global rank
72 global_rank = dist.get_rank()
73
74 config = _generate_mics_config(world_size, ndevices_per_node, shard_size, 1)
75 ranks_of_shard_group = config['shard_groups']
76 ranks_of_repli_group = config['replicate_groups']
77 if len(ranks_of_repli_group) == 0:
78 assert len(ranks_of_shard_group) == 1, "replicate groups are empty only for single shard group"
79 for r in ranks_of_shard_group[0]:
80 ranks_of_repli_group.append([r])
81
82 # for simplicity
83 assert _sizes_all_same(ranks_of_repli_group), "replicate groups must have the same size"
84 assert _sizes_all_same(ranks_of_shard_group), "shard groups must have the same size"
85
86 assert sum([len(g) for g in ranks_of_shard_group]) == dist.get_world_size(), "all sharded ranks "
87 if len(ranks_of_shard_group) > 1: # if only shard on one group then no need for replicate groups
88 assert len(ranks_of_shard_group) == len(
89 ranks_of_repli_group[0]), "number of shard groups must equal to the size of each replicate group"
90
91 global_rank = dist.get_rank()
92 # create shard groups
93 for shard_ranks in ranks_of_shard_group:
94 _group = dist.new_group(shard_ranks)
95 if global_rank in shard_ranks:
96 groups.param_shard_group = _group
97 groups.param_shard_size = len(shard_ranks)
98 groups.param_shard_rank = dist.get_rank(_group)
99 logger.info(f'rank {global_rank}, shard group'
100 f' {groups.param_shard_rank}/{dist.get_world_size(group=_group)}')
101
102 # create replicate groups
103 for repli_ranks in ranks_of_repli_group:
104 if len(repli_ranks) > 1:
105 _group = dist.new_group(repli_ranks)

Callers 1

__init__Method · 0.90

Calls 12

get_acceleratorFunction · 0.90
_log_rank0Function · 0.85
MiCS_CommGroupsClass · 0.85
_generate_mics_configFunction · 0.85
_sizes_all_sameFunction · 0.85
get_world_sizeMethod · 0.80
appendMethod · 0.80
getMethod · 0.45
device_countMethod · 0.45
get_rankMethod · 0.45
new_groupMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…