```python
import gc
from math import prod  # assumption: stands in for the module's own ``prod`` helper (multiplies all sizes)
from typing import Dict, Tuple, Union

import torch.distributed as dist
from torch.distributed import ProcessGroup
from torch.distributed.distributed_c10d import GroupMember


class ProcessGroupMesh:
    """A helper class to manage the process group mesh. It only describes how to organize process groups and is
    decoupled from any specific parallel method. It only initializes process groups and caches them; the parallel
    method is responsible for managing them and using them to do the parallel computation.

    We use an N-D tuple to represent the process group mesh and an N-D coordinate to represent each process.
    For example, ``(0, 1, 0)`` represents the process whose rank is 2 in a 3D process group mesh with size ``(2, 2, 2)``.

    Args:
        *size (int): The size of each dimension of the process group mesh. The product of all sizes must equal the
            world size.

    Attributes:
        shape (Tuple[int, ...]): The shape of the process group mesh.
        rank (int): The rank of the current process.
    """

    def __init__(self, *size: int) -> None:
        assert dist.is_initialized(), "Please initialize torch.distributed first."
        world_size = dist.get_world_size()
        prod_size = prod(size)
        assert (
            prod_size == world_size
        ), f"The product of the sizes ({prod_size}) must equal the world size ({world_size})."

        self._shape = size
        self._rank = dist.get_rank()
        # ``unravel`` (flat rank -> N-D coordinate) is defined later in this class, outside this excerpt.
        self._coord = ProcessGroupMesh.unravel(self._rank, self._shape)
        # Two-way cache between rank tuples and groups. A cached value may be ``GroupMember.NON_GROUP_MEMBER``
        # when the current rank is not a member of that group.
        self._ranks_to_group: Dict[Tuple[int, ...], Union[ProcessGroup, GroupMember.NON_GROUP_MEMBER]] = {}
        self._group_to_ranks: Dict[ProcessGroup, Tuple[int, ...]] = {}

    def destroy_mesh_process_groups(self):
        r"""
        Destroy all process groups created by this ProcessGroupMesh.

        Unlike ``__del__``, Python does not call this method automatically when the object is deleted or goes out
        of scope; call it explicitly once the mesh's process groups are no longer needed. It cleans up every process
        group created during the lifetime of the object.

        Note:
            PyTorch tracks process groups in module-level global state, so they are not destroyed automatically
            when the ProcessGroupMesh's lifetime ends. This method destroys them manually to release system
            resources.
        """
        for group in self._ranks_to_group.values():
            try:
                dist.destroy_process_group(group)
            except ValueError:
                # The group is already destroyed (or otherwise unknown to torch.distributed); skip it.
                pass

        # Trigger garbage collection to reclaim memory released by the destroyed groups.
        gc.collect()

    @property
    def shape(self) -> Tuple[int, ...]:
        return self._shape

    @property
    def rank(self) -> int:
        return self._rank
```
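The class docstring maps a flat rank to an N-D coordinate, but ``ProcessGroupMesh.unravel`` itself is not part of this excerpt. The pure-Python sketch below shows one mapping that agrees with the docstring's example (rank 2 ↔ ``(0, 1, 0)`` in a ``(2, 2, 2)`` mesh). It assumes row-major (C-order) layout, where the last dimension varies fastest; this is the same convention NumPy's `np.unravel_index` and `np.ravel_multi_index` use by default. The helper names `unravel_rank` and `ravel_coord` are hypothetical.

```python
from math import prod
from typing import Tuple


def unravel_rank(rank: int, shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Hypothetical helper: flat rank -> N-D mesh coordinate, assuming row-major order."""
    if not 0 <= rank < prod(shape):
        raise ValueError(f"rank {rank} is out of range for mesh shape {shape}")
    coord = []
    # Peel off dimensions from the last (fastest-varying) to the first.
    for size in reversed(shape):
        rank, index = divmod(rank, size)
        coord.append(index)
    return tuple(reversed(coord))


def ravel_coord(coord: Tuple[int, ...], shape: Tuple[int, ...]) -> int:
    """Hypothetical helper: N-D mesh coordinate -> flat rank, assuming row-major order."""
    if len(coord) != len(shape):
        raise ValueError("coordinate and shape must have the same number of dimensions")
    rank = 0
    for index, size in zip(coord, shape):
        if not 0 <= index < size:
            raise ValueError(f"index {index} is out of range for a dimension of size {size}")
        # Horner-style accumulation: shift by this dimension's size, then add the index.
        rank = rank * size + index
    return rank


shape = (2, 2, 2)
print(unravel_rank(2, shape))         # (0, 1, 0)
print(ravel_coord((0, 1, 0), shape))  # 2
```

The two helpers are inverses, so every rank in ``range(prod(shape))`` round-trips to itself.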
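``__init__`` sets up two dictionaries that cache groups in both directions, and ``destroy_mesh_process_groups`` walks that cache to tear the groups down while skipping any that raise ``ValueError``. The sketch below reproduces that get-or-create and destroy pattern without ``torch.distributed`` so it can run anywhere. `GroupCache`, `get_or_create`, `ranks_of`, `destroy_all`, `fake_new_group`, and `fake_destroy` are all hypothetical. In a real run the factory would plausibly wrap `torch.distributed.new_group(ranks=...)`, which every process must call and which returns `GroupMember.NON_GROUP_MEMBER` on ranks outside the group (presumably why the cached value type includes it). The destroy callable would be `torch.distributed.destroy_process_group`.

```python
from typing import Callable, Dict, Hashable, List, Tuple

Ranks = Tuple[int, ...]


class GroupCache:
    """Hypothetical two-way cache modeled on ``_ranks_to_group`` / ``_group_to_ranks``."""

    def __init__(self, factory: Callable[[Ranks], Hashable], destroy: Callable[[Hashable], None]) -> None:
        self._factory = factory  # assumption: stand-in for something like dist.new_group(ranks=list(ranks))
        self._destroy = destroy  # assumption: stand-in for dist.destroy_process_group
        self._ranks_to_group: Dict[Ranks, Hashable] = {}
        self._group_to_ranks: Dict[Hashable, Ranks] = {}

    def get_or_create(self, ranks: Ranks) -> Hashable:
        # Create the group only on the first request for this rank tuple; later calls reuse it.
        if ranks not in self._ranks_to_group:
            group = self._factory(ranks)
            self._ranks_to_group[ranks] = group
            self._group_to_ranks[group] = ranks
        return self._ranks_to_group[ranks]

    def ranks_of(self, group: Hashable) -> Ranks:
        # Reverse lookup: which ranks belong to this group.
        return self._group_to_ranks[group]

    def destroy_all(self) -> None:
        for group in self._ranks_to_group.values():
            try:
                self._destroy(group)
            except ValueError:
                # Same policy as destroy_mesh_process_groups: skip groups that are already gone.
                pass
        # Unlike destroy_mesh_process_groups, this sketch also drops the stale handles from both maps.
        self._ranks_to_group.clear()
        self._group_to_ranks.clear()


# Fake backend so the sketch runs without torch.distributed.
created: List[Ranks] = []
destroyed: List[Hashable] = []


def fake_new_group(ranks: Ranks) -> str:
    created.append(ranks)
    return f"group{len(created) - 1}"


def fake_destroy(group: Hashable) -> None:
    if group in destroyed:
        raise ValueError("group already destroyed")
    destroyed.append(group)


cache = GroupCache(fake_new_group, fake_destroy)
first = cache.get_or_create((0, 1))
second = cache.get_or_create((0, 1))  # cache hit: fake_new_group is not called again
other = cache.get_or_create((2, 3))
other_ranks = cache.ranks_of(other)
cache.destroy_all()
print(created)    # [(0, 1), (2, 3)]
print(destroyed)  # ['group0', 'group1']
```

Keeping the reverse map lets code that only holds a group handle recover its member ranks without recomputing mesh coordinates. Swallowing ``ValueError`` during teardown makes repeated or overlapping cleanup calls harmless.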