hub / github.com/Robbyant/lingbot-world / all_to_all

Function all_to_all

wan/distributed/util.py:20–30

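`scatter` along one dimension and `gather` along another: `x` is split into `world_size` chunks along `scatter_dim`, the chunks are exchanged across ranks with `dist.all_to_all`, and the received chunks are concatenated along `gather_dim`. When the world size is 1, `x` is returned unchanged.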

(x, scatter_dim, gather_dim, group=None, **kwargs)
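The effect on the tensor shape can be sketched with plain arithmetic: the size along `scatter_dim` shrinks by a factor of `world_size`, and the size along `gather_dim` grows by the same factor. The helper below is hypothetical (it is not part of `wan/distributed/util.py`) and assumes the size along `scatter_dim` is divisible by `world_size`, so that every chunk has the same shape.

```python
def all_to_all_output_shape(shape, scatter_dim, gather_dim, world_size):
    """Hypothetical helper: per-rank output shape of a scatter/gather all-to-all.

    Assumes shape[scatter_dim] is divisible by world_size so all chunks are equal.
    """
    if world_size <= 1:
        # Mirrors the single-process path: the input is returned unchanged.
        return tuple(shape)
    if shape[scatter_dim] % world_size != 0:
        raise ValueError("size along scatter_dim must be divisible by world_size")
    out = list(shape)
    out[scatter_dim] //= world_size  # each rank keeps one chunk per sender
    out[gather_dim] *= world_size    # chunks from all ranks are concatenated
    return tuple(out)


# Example: 4 ranks, scatter dimension 2, gather dimension 1.
print(all_to_all_output_shape((2, 8, 16, 64), scatter_dim=2, gather_dim=1, world_size=4))
# (2, 32, 4, 64)
```

The total number of elements per rank is unchanged; only the layout across ranks differs.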

Source from the content-addressed store, hash-verified

def all_to_all(x, scatter_dim, gather_dim, group=None, **kwargs):
    """
    `scatter` along one dimension and `gather` along another.
    """
    world_size = get_world_size()
    if world_size > 1:
        inputs = [u.contiguous() for u in x.chunk(world_size, dim=scatter_dim)]
        outputs = [torch.empty_like(u) for u in inputs]
        dist.all_to_all(outputs, inputs, group=group, **kwargs)
        x = torch.cat(outputs, dim=gather_dim).contiguous()
    return x
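The data movement can be traced without a process group. The sketch below simulates the exchange in a single process on small 2-D Python lists, one per rank. It is an illustration under stated assumptions, not the repository's code: `chunk`, `cat`, and `simulate_all_to_all` are hypothetical names, the tensors are limited to two dimensions, and sizes are assumed divisible by the number of ranks. It relies on the documented behaviour of `dist.all_to_all`, where rank `r` receives the `r`-th input chunk of every sender, ordered by sender rank.

```python
def chunk(mat, n, dim):
    """Split a 2-D list into n equal pieces along dim (0 = rows, 1 = columns)."""
    if dim == 0:
        step = len(mat) // n
        return [mat[i * step:(i + 1) * step] for i in range(n)]
    step = len(mat[0]) // n
    return [[row[i * step:(i + 1) * step] for row in mat] for i in range(n)]


def cat(mats, dim):
    """Concatenate 2-D lists along dim (0 = rows, 1 = columns)."""
    if dim == 0:
        return [row for m in mats for row in m]
    return [[v for m in mats for v in m[r]] for r in range(len(mats[0]))]


def simulate_all_to_all(per_rank, scatter_dim, gather_dim):
    """Simulate the exchange for every rank at once; per_rank[r] is rank r's input."""
    world_size = len(per_rank)
    if world_size <= 1:
        return per_rank
    # Each rank splits its local data into world_size chunks along scatter_dim.
    inputs = [chunk(x, world_size, scatter_dim) for x in per_rank]
    # Rank r receives chunk r from every sender s, then concatenates along gather_dim.
    return [
        cat([inputs[s][r] for s in range(world_size)], gather_dim)
        for r in range(world_size)
    ]


rank0 = [[0, 1, 2, 3], [4, 5, 6, 7]]
rank1 = [[10, 11, 12, 13], [14, 15, 16, 17]]
out = simulate_all_to_all([rank0, rank1], scatter_dim=1, gather_dim=0)
print(out[0])  # [[0, 1], [4, 5], [10, 11], [14, 15]]
print(out[1])  # [[2, 3], [6, 7], [12, 13], [16, 17]]
```

Calling the simulation again with the two dimensions swapped (`scatter_dim=0, gather_dim=1`) restores the original per-rank data in this example.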

Callers 2

distributed_attention · Function · 0.85
sp_attn_forward_causal · Function · 0.85

Calls 1

get_world_size · Function · 0.85

Tested by

no test coverage detected