`scatter` along one dimension and `gather` along another.
(x, scatter_dim, gather_dim, group=None, **kwargs)
| 18 | |
| 19 | |
def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
    """
    `scatter` along one dimension and `gather` along another.

    Splits ``x`` into ``world_size`` equal chunks along ``scatter_dim``, sends
    chunk ``i`` to rank ``i`` of ``group`` via ``torch.distributed.all_to_all``,
    and concatenates the chunks received from every rank along ``gather_dim``.
    When the participating world size is 1 (or less), ``x`` is returned as is.

    Args:
        x: Local input tensor.
        scatter_dim: Dimension along which ``x`` is split across ranks. Its
            size must be divisible by the world size.
        gather_dim: Dimension along which the received chunks are concatenated.
        group: Process group to communicate over. ``None`` means the default
            group, whose size is taken from ``get_world_size()``.
        **kwargs: Forwarded unchanged to ``torch.distributed.all_to_all``.

    Returns:
        The re-sharded tensor, or ``x`` itself when no exchange is needed.

    Raises:
        ValueError: If ``x.size(scatter_dim)`` is not divisible by the world
            size. Uneven chunks would make the receive buffers below the wrong
            shape.
    """
    # The exchange happens among the ranks of `group`, so that group's size
    # (not the global world size) decides how many chunks to produce.
    world_size = get_world_size() if group is None else dist.get_world_size(group)
    if world_size <= 1:
        return x

    # dist.all_to_all requires exactly `world_size` tensors in each list. The
    # receive buffers are shaped from our own chunks, which is only correct
    # when every rank sends equally sized chunks.
    dim_size = x.size(scatter_dim)
    if dim_size % world_size != 0:
        raise ValueError(
            f"all_to_all: size {dim_size} of scatter_dim={scatter_dim} is not "
            f"divisible by world size {world_size}"
        )

    inputs = [chunk.contiguous() for chunk in x.chunk(world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(chunk) for chunk in inputs]
    dist.all_to_all(outputs, inputs, group=group, **kwargs)
    return torch.cat(outputs, dim=gather_dim).contiguous()
| 31 | |
| 32 | |
| 33 | def all_gather(tensor): |
no test coverage detected