Adapted from megatron.p2p_communication. Communicate tensors between stages. Used as helper method in other communication methods that are used in pipeline schedule. Takes the following arguments: object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank (no tensor sent if set to None).
(
object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None,
object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None,
recv_prev: bool = False,
recv_next: bool = False,
recv_prev_shape: Union[torch.Size, List[torch.Size]] = None,
recv_next_shape: Union[torch.Size, List[torch.Size]] = None,
prev_rank: int = None,
next_rank: int = None,
dtype: torch.dtype = None,
scatter_gather_tensors: bool = False,
)
| 85 | |
| 86 | |
| 87 | def _communicate( |
| 88 | object_send_next: Union[torch.Tensor, List[torch.Tensor]] = None, |
| 89 | object_send_prev: Union[torch.Tensor, List[torch.Tensor]] = None, |
| 90 | recv_prev: bool = False, |
| 91 | recv_next: bool = False, |
| 92 | recv_prev_shape: Union[torch.Size, List[torch.Size]] = None, |
| 93 | recv_next_shape: Union[torch.Size, List[torch.Size]] = None, |
| 94 | prev_rank: int = None, |
| 95 | next_rank: int = None, |
| 96 | dtype: torch.dtype = None, |
| 97 | scatter_gather_tensors: bool = False, |
| 98 | ) -> Tuple[Union[torch.Tensor, List[torch.Tensor]]]: |
| 99 | """ |
| 100 | Adapted from megatron.p2p_communication. |
| 101 | Communicate tensors between stages. Used as helper method in other |
| 102 | communication methods that are used in pipeline schedule. |
| 103 | Takes the following arguments: |
| 104 | object_send_next (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to next rank |
| 105 | (no tensor sent if set to None). |
| 106 | object_send_prev (Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]): tensor to send to prev rank |
| 107 | (no tensor sent if set to None). |
| 108 | recv_prev (bool): boolean for whether tensor should be received from |
| 109 | previous rank. |
| 110 | recv_next (bool): boolean for whether tensor should be received from |
| 111 | next rank. |
| 112 | recv_prev_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received |
| 113 | from the previous stage, defaults to None. |
| 114 | recv_next_shape (Union[:class:`torch.Size`, List[:class:`torch.Size`]]): shape of the tensor to be received |
| 115 | from the next stage, defaults to None. |
| 116 | prev_rank (int): the rank of the previous pipeline stage, defaults to None, |
| 117 | next_rank (int): the rank of the next pipeline stage, defaults to None, |
| 118 | dtype (torch.dtype): data type of intermediate buffers, defaults to None |
| 119 | scatter_gather_tensors (bool): whether to scatter and gather tensor between pipeline stages, defaults to False |
| 120 | |
| 121 | Returns: |
| 122 | Tuple[Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]]: returns tensor_recv_prev, tensor_recv_next |
| 123 | """ |
| 124 | |
| 125 | # Create placeholder tensors for receive in forward and backward directions |
| 126 | # if needed. |
| 127 | tensor_recv_prev = None |
| 128 | tensor_recv_next = None |
| 129 | |
| 130 | if recv_prev: |
| 131 | assert recv_prev_shape is not None |
| 132 | tensor_recv_prev, recv_prev_split = create_recv_buffer_with_shapes( |
| 133 | recv_prev_shape, dtype, scatter_gather_tensors |
| 134 | ) |
| 135 | |
| 136 | if recv_next: |
| 137 | assert recv_next_shape is not None |
| 138 | tensor_recv_next, recv_next_split = create_recv_buffer_with_shapes( |
| 139 | recv_next_shape, dtype, scatter_gather_tensors |
| 140 | ) |
| 141 | |
| 142 | if object_send_prev is not None or recv_prev: |
| 143 | if prev_rank is None: |
| 144 | prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE) |
no test coverage detected