Opposite of the above function: gathers the 1D chunks held by each tensor-parallel rank into a single flat tensor.

Args:
    tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.

Returns:
    :class:`torch.Tensor`: The gathered tensor.

Signature: gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor
def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Opposite of above function, gather values from model parallel ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.

    Returns:
        :class:`torch.Tensor`: The gathered tensor.
    """
    world_size = gpc.get_world_size(ParallelMode.TENSOR)
    numel = torch.numel(tensor)
    numel_gathered = world_size * numel
    # Flat output buffer with room for one chunk per tensor-parallel rank.
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
    # Slices are views into `gathered`, so all_gather writes straight into it.
    chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR))
    return gathered
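The buffer layout used above (one preallocated flat buffer, one view per rank, rank i's data landing in slot i) can be illustrated without a process group. The sketch below is a single-process simulation in plain Python, not the library's implementation: `simulate_gather_split_1d` and `rank_chunks` are hypothetical names, `array`/`memoryview` stand in for `torch.empty` and tensor slicing, and a simple loop stands in for `dist.all_gather`.

```python
from array import array


def simulate_gather_split_1d(rank_chunks):
    """Simulate the gather layout on one process (hypothetical helper).

    rank_chunks[i] stands in for the local 1D tensor held by rank i.
    All chunks must have the same length, as the gather above assumes.
    """
    world_size = len(rank_chunks)
    numel = len(rank_chunks[0])
    if any(len(chunk) != numel for chunk in rank_chunks):
        raise ValueError("every rank must contribute the same number of elements")

    # One flat buffer for the result, analogous to torch.empty(world_size * numel).
    gathered = array("d", [0.0] * (world_size * numel))
    view = memoryview(gathered)

    # Slicing a memoryview yields views into `gathered`, not copies,
    # mirroring how the tensor slices alias the output buffer.
    chunks = [view[i * numel:(i + 1) * numel] for i in range(world_size)]

    # Stand-in for the collective: slot i receives rank i's data.
    for slot, data in zip(chunks, rank_chunks):
        slot[:] = array("d", data)

    return gathered.tolist()


if __name__ == "__main__":
    # Three simulated ranks, two elements each.
    print(simulate_gather_split_1d([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
```

Because the slots alias the output buffer, no concatenation step is needed after the writes; the real function relies on the same property, but additionally requires an initialized tensor-parallel process group and a CUDA device.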