(input, dim)
| 40 | |
| 41 | |
def gather_forward(input, dim):
    """Collect ``input`` from every rank and join the pieces along ``dim``.

    If the default process group reports a world size of exactly 1, nothing
    is gathered and ``input`` itself is returned (same object, no copy).

    Args:
        input: this rank's local tensor. The name shadows the builtin but is
            part of the public signature, so it is kept.
        dim: dimension along which the gathered pieces are concatenated.
            The original comment ("gather sequence") suggests this is usually
            a sequence dimension -- NOTE(review): confirm with callers.

    Returns:
        ``input`` unchanged when the world size is 1. Otherwise, a contiguous
        tensor produced by ``torch.cat`` over whatever the file's
        ``all_gather`` helper returns.
    """
    if dist.get_world_size() != 1:
        # ``all_gather`` is a helper defined elsewhere in this file, not
        # ``torch.distributed.all_gather``. Presumably it returns one tensor
        # per rank -- verify against its definition.
        per_rank = all_gather(input)
        joined = torch.cat(per_rank, dim=dim)
        return joined.contiguous()
    return input
no test coverage detected