Add an operation that performs a collective all-gather across the TP group. If 'sizes' is 'None', the input tensors in the different ranks must have the same shape. Otherwise, 'sizes[i]' must be 'input.shape[dim]' at rank i, and the input tensors in the different ranks can only differ in shape at dimension 'dim'.
allgather(
    input: Union[torch.Tensor, List[torch.Tensor]],
    mapping: Mapping,
    dim: int = -1,
    sizes: Optional[List[int]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]
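The shape rules in the description can be checked without a distributed runtime. The sketch below is not part of the library: `gathered_shape` is a hypothetical helper, and `shapes[i]` stands in for `input.shape` at rank i of the TP group. It only predicts the output shape under the stated rules; it performs no communication.

```python
from typing import List, Optional, Sequence, Tuple


def gathered_shape(
    shapes: Sequence[Sequence[int]],
    dim: int = -1,
    sizes: Optional[List[int]] = None,
) -> Tuple[int, ...]:
    """Hypothetical helper: predict the all-gather output shape.

    shapes[i] plays the role of input.shape at rank i (an assumption
    made for illustration; the real op sees only the local tensor).
    """
    shapes = [tuple(s) for s in shapes]
    if not shapes:
        raise ValueError("need at least one rank")
    ndim = len(shapes[0])
    if any(len(s) != ndim for s in shapes):
        raise ValueError("all ranks must have the same number of dimensions")
    if not -ndim <= dim < ndim:
        raise ValueError("dim is out of range")
    axis = dim % ndim  # normalize a negative dim such as -1

    if sizes is None:
        # Without sizes, every rank must hold an identically shaped tensor.
        if any(s != shapes[0] for s in shapes):
            raise ValueError("shapes must match on all ranks when sizes is None")
        gathered = shapes[0][axis] * len(shapes)
    else:
        if len(sizes) != len(shapes):
            raise ValueError("sizes needs one entry per rank")
        reference = shapes[0][:axis] + shapes[0][axis + 1:]
        for rank, (shape, size) in enumerate(zip(shapes, sizes)):
            # sizes[i] must equal input.shape[dim] at rank i.
            if shape[axis] != size:
                raise ValueError(f"sizes[{rank}] does not match rank {rank}")
            # Ranks may differ only at dimension dim.
            if shape[:axis] + shape[axis + 1:] != reference:
                raise ValueError(f"rank {rank} differs outside dimension dim")
        gathered = sum(sizes)

    out = list(shapes[0])
    out[axis] = gathered
    return tuple(out)
```

For example, four ranks that each hold a `(2, 8)` tensor give `gathered_shape([(2, 8)] * 4)` equal to `(2, 32)`, while two ranks holding `(2, 3)` and `(2, 5)` give `gathered_shape([(2, 3), (2, 5)], dim=1, sizes=[3, 5])` equal to `(2, 8)`.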
def allgather(
    input: Union[torch.Tensor, List[torch.Tensor]],
    mapping: Mapping,
    dim: int = -1,
    sizes: Optional[List[int]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
    '''
    Add an operation that performs a collective all-gather across the TP group.

    If 'sizes' is 'None', the input tensors in the different ranks must have the same shape.
    Otherwise, 'sizes[i]' must be 'input.shape[dim]' at rank i, and the input tensors in
    the different ranks can only differ in shape at dimension 'dim'.

    The input tensors in the same TP group are concatenated at dimension 'dim' to produce the output tensor.
    If 'sizes' is 'None', 'output.shape[dim] = input.shape[dim] * tp_group_size'.
    Otherwise, 'output.shape[dim] = sum(sizes)'.

    This operation is implemented using a torch op that wraps either the NCCL all-gather collective
    operation or an NCCL group call of a series of NCCL broadcast collective operations.
    See the following materials for details:
    https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather
    https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#broadcast
    https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/group.html

    Args:
        input (Union[Tensor, List[Tensor]]): The input tensor or tensor list.
        mapping (Mapping): The parallel mapping.
        dim (int): The dimension to gather along. By default -1.
        sizes (Optional[List[int]]): An optional list giving 'input.shape[dim]' at each rank. By default None.
    Returns:
        The gathered tensor or tensor list.
    '''
    group_boxed = mapping.tp_group_pg.boxed() if mpi_disabled() else None
    return _allgather(input, mapping.tp_group, mapping.tp_rank, group_boxed,
                      dim, sizes)


def cp_allgather(