MCPcopy
hub / github.com/InternLM/InternLM / _gather

Function _gather

internlm/model/utils.py:44–60  ·  view source on GitHub ↗
(input_, parallel_mode, dim=-1)

Source from the content-addressed store, hash-verified

42
43
def _gather(input_, parallel_mode, dim=-1):
    """All-gather ``input_`` across the ranks of ``parallel_mode`` and concatenate the shards.

    Args:
        input_ (torch.Tensor): The local shard held by this rank. Every rank is assumed to
            hold a tensor of the same shape and dtype, since the receive buffers are built
            with ``torch.empty_like(input_)``.
        parallel_mode: Parallel mode used to look up (via ``gpc``) the world size, this
            process's local rank, and the process group to communicate over.
        dim (int): Dimension along which the gathered shards are concatenated. Defaults to -1.

    Returns:
        torch.Tensor: ``input_`` unchanged when only one rank is involved; otherwise a
        contiguous tensor holding every rank's shard, concatenated along ``dim`` in the
        order ``all_gather`` places them in the output list.
    """
    # skip if only one rank involved
    world_size = gpc.get_world_size(parallel_mode)
    if world_size == 1:
        return input_

    # torch.distributed collectives reject non-contiguous tensors, and ``empty_like``
    # would propagate a non-contiguous (e.g. transposed) layout into every receive
    # buffer. ``contiguous()`` returns the tensor itself when it already is contiguous.
    input_ = input_.contiguous()

    # all gather: this rank's slot reuses ``input_`` directly instead of allocating a
    # buffer that would immediately be replaced.
    rank = gpc.get_local_rank(parallel_mode)
    tensor_list = [input_ if i == rank else torch.empty_like(input_) for i in range(world_size)]

    # CPU tensors communicate over the mode's CPU process group; tensors on any
    # other device use the regular group for this mode.
    if input_.device.type == "cpu":
        group = gpc.get_cpu_group(parallel_mode)
    else:
        group = gpc.get_group(parallel_mode)
    torch.distributed.all_gather(tensor_list, input_, group=group)

    # concat
    output = torch.cat(tensor_list, dim=dim).contiguous()

    return output
61
62
63class _GatherForwardSplitBackward(torch.autograd.Function):

Callers 3

symbolic (Method) · 0.85
forward (Method) · 0.85
backward (Method) · 0.85

Calls 4

get_world_size (Method) · 0.80
get_local_rank (Method) · 0.80
get_cpu_group (Method) · 0.80
get_group (Method) · 0.80

Tested by

no test coverage detected