github.com/hpcaitech/ColossalAI

Class ProcessGroupMesh

colossalai/cluster/process_group_mesh.py:25–276

A helper class to manage the process group mesh. It only describes how to organize process groups and is decoupled from the parallel method. It initializes process groups and caches them, while the parallel method manages them and uses them for the parallel computation. An ND-tuple represents the process group mesh, and an ND-coordinate represents each process.
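The mesh's coordinate convention can be checked in isolation. The sketch below assumes row-major order (the last axis varies fastest), which is consistent with the docstring's example that coordinate `(0, 1, 0)` is rank 2 in a `(2, 2, 2)` mesh. The functions `unravel` and `ravel` here are plain-Python illustrations, not the class's own `ProcessGroupMesh.unravel`, whose body is outside this excerpt.

```python
from typing import Tuple


def unravel(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Convert a flat rank into an ND-coordinate (row-major: last axis varies fastest)."""
    coord = []
    for size in reversed(shape):
        coord.append(rank % size)  # position along this axis
        rank //= size              # move on to the next (slower-varying) axis
    return tuple(reversed(coord))


def ravel(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int:
    """Convert an ND-coordinate back into a flat rank (inverse of unravel)."""
    rank = 0
    for c, size in zip(coord, shape):
        rank = rank * size + c
    return rank


shape = (2, 2, 2)
print(ravel((0, 1, 0), shape))  # 2
print(unravel(2, shape))        # (0, 1, 0)
```

Because this mapping is a bijection between ranks `0 .. prod(shape) - 1` and coordinates, the `__init__` requirement that the product of the sizes equals the world size is what guarantees every process receives exactly one coordinate.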


```python
import gc
from math import prod  # stand-in for the module-level prod() helper defined above the class (not shown)
from typing import Dict, Tuple, Union

import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import GroupMember


class ProcessGroupMesh:
    """A helper class to manage the process group mesh. It only describes how to organize process groups and is decoupled from the parallel method.
    It only initializes process groups and caches them. The parallel method is responsible for managing them and using them for the parallel computation.

    We use an ND-tuple to represent the process group mesh and an ND-coordinate to represent each process.
    For example, ``(0, 1, 0)`` represents the process whose rank is 2 in a 3D process group mesh with size ``(2, 2, 2)``.

    Args:
        *size (int): The size of each dimension of the process group mesh. The product of the sizes must equal the world size.

    Attributes:
        shape (Tuple[int, ...]): The shape of the process group mesh.
        rank (int): The rank of the current process.
    """

    def __init__(self, *size: int) -> None:
        assert dist.is_initialized(), "Please initialize torch.distributed first."
        world_size = dist.get_world_size()
        prod_size = prod(size)
        assert (
            prod_size == world_size
        ), f"The product of the size({prod_size}) must be equal to the world size({world_size})."

        self._shape = size
        self._rank = dist.get_rank()
        self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
        self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember.NON_GROUP_MEMBER]] = {}
        self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}

    def destroy_mesh_process_groups(self):
        r"""
        Destructor method for the ProcessGroupMesh class.

        When the ProcessGroupMesh object is deleted or goes out of scope, this method is called. It is responsible for
        cleaning up any process groups that were created during the lifetime of the object.

        Note:
            All process groups in PyTorch are represented as global variables, and they may not be automatically destroyed
            when the ProcessGroupMesh's lifetime ends. This method manually destroys the process groups to release
            system resources.
        """
        for group in self._ranks_to_group.values():
            try:
                dist.destroy_process_group(group)
            except ValueError:
                # torch.distributed raises ValueError for groups it does not recognize
                # (e.g. NON_GROUP_MEMBER placeholders on non-member ranks); skip them.
                pass

        # Run the garbage collector so memory held by the destroyed groups is reclaimed
        gc.collect()

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._shape

    @property
    def rank(self) -> int:
        return self._rank

    # ... remaining methods (through line 276 of the file) are not shown in this excerpt
```
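`__init__` enforces one invariant and sets up two dictionaries, `_ranks_to_group` and `_group_to_ranks`, so a group can be looked up from its member ranks and vice versa. The following is a minimal, torch-free sketch of that bookkeeping under stated assumptions: `MeshCacheSketch`, its `get_group` method, and the `new_group` factory argument are hypothetical (the real class's accessor methods are not in this excerpt), and `fake_new_group` stands in for `torch.distributed.new_group`.

```python
from math import prod
from typing import Callable, Dict, Hashable, Tuple


class MeshCacheSketch:
    """Hypothetical, torch-free stand-in for ProcessGroupMesh's bookkeeping."""

    def __init__(self, world_size: int, *size: int) -> None:
        # Same invariant as ProcessGroupMesh.__init__: the mesh must cover every rank exactly once.
        if prod(size) != world_size:
            raise ValueError(
                f"The product of the sizes ({prod(size)}) must equal the world size ({world_size})."
            )
        self.shape = size
        # Two-way cache: member ranks -> group, and group -> member ranks.
        self._ranks_to_group: Dict[Tuple[int, ...], Hashable] = {}
        self._group_to_ranks: Dict[Hashable, Tuple[int, ...]] = {}

    def get_group(
        self, ranks: Tuple[int, ...], new_group: Callable[[Tuple[int, ...]], Hashable]
    ) -> Hashable:
        """Return the cached group for `ranks`, calling `new_group` only on the first request."""
        if ranks not in self._ranks_to_group:
            group = new_group(ranks)
            self._ranks_to_group[ranks] = group
            self._group_to_ranks[group] = ranks
        return self._ranks_to_group[ranks]


created = []


def fake_new_group(ranks: Tuple[int, ...]) -> str:
    # Hypothetical stand-in for torch.distributed.new_group: records the call, returns a unique label.
    created.append(ranks)
    return f"group{len(created)}"


mesh = MeshCacheSketch(8, 2, 2, 2)
g1 = mesh.get_group((0, 1), fake_new_group)
g2 = mesh.get_group((0, 1), fake_new_group)
print(g1 == g2, len(created))  # True 1
```

Creating each group once and reusing it matters because `torch.distributed.new_group` is a collective call that every process must enter. In the real class the cached value may also be `GroupMember.NON_GROUP_MEMBER`, which `new_group` returns on ranks outside the requested group, as the type annotation on `_ranks_to_group` indicates.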

Callers (15)

main (function) · 0.90
main (function) · 0.90
__init__ (method) · 0.90
__init__ (method) · 0.90
__init__ (method) · 0.90
init_model (method) · 0.90
init_model (method) · 0.90
_init_model (method) · 0.90
__init__ (method) · 0.90
__init__ (method) · 0.90
__init__ (method) · 0.90
__init__ (method) · 0.90

Calls

no outgoing calls

Tested by (15)

run_dist_lamb_basic (function) · 0.72
run_dist_lamb_fwd_bwd (function) · 0.72
exam_dist_came_base (function) · 0.72
run_dist_galore_basic (function) · 0.72
run_dist_galore_fwd_bwd (function) · 0.72
exam_dist_adafactor_base (function) · 0.72
exam_dist_adafactor_zero (function) · 0.72
check_stage_manager (function) · 0.72
check_p2p_communication (function) · 0.72
run_fwd_bwd_iter_input (function) · 0.72
examine_pp (function) · 0.72
