(input_, parallel_mode, dim=-1)
| 42 | |
| 43 | |
def _gather(input_, parallel_mode, dim=-1):
    """All-gather ``input_`` across the ranks of ``parallel_mode`` and concatenate.

    Args:
        input_: This rank's local tensor. Every rank in the group is expected
            to pass a tensor of the same shape and dtype, as required by
            ``torch.distributed.all_gather``.
        parallel_mode: Key handed to ``gpc`` to look up the world size, the
            local rank and the process group used for the collective.
        dim: Dimension along which the gathered tensors are concatenated
            (default: the last dimension).

    Returns:
        ``input_`` itself (unchanged) when only one rank is involved.
        Otherwise, a new contiguous tensor holding every rank's tensor
        concatenated along ``dim`` in rank order.
    """
    # skip if only one rank involved
    world_size = gpc.get_world_size(parallel_mode)
    if world_size == 1:
        return input_

    # Collective backends expect contiguous tensors. ``empty_like`` would also
    # copy a non-contiguous (e.g. transposed) input's strides into the receive
    # buffers, so normalize the layout first. This is a no-op when the input is
    # already contiguous.
    input_ = input_.contiguous()

    # all gather
    rank = gpc.get_local_rank(parallel_mode)
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    # Reuse the local tensor for this rank's own slot instead of an empty buffer.
    tensor_list[rank] = input_
    # CPU tensors use a separate CPU process group (presumably Gloo-backed,
    # since NCCL only handles GPU tensors; confirm in gpc).
    group = gpc.get_cpu_group(parallel_mode) if input_.device.type == "cpu" else gpc.get_group(parallel_mode)
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # concat
    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output
| 61 | |
| 62 | |
| 63 | class _GatherForwardSplitBackward(torch.autograd.Function): |
no test coverage detected