Repository: github.com/InternLM/InternLM

Function gather_split_1d_tensor

Defined in internlm/core/communication/utils.py, lines 111–125.

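Inverse of the split function defined just above it in the same source file (not shown here): gathers values from the tensor-parallel ranks (ParallelMode.TENSOR). Every rank contributes its local tensor, and the result is a new flat 1D tensor of world_size * tensor.numel() elements on the current CUDA device, with rank i's values at offsets [i * numel, (i + 1) * numel).

Args:
    tensor (torch.Tensor): Tensor to be gathered after communication.
Returns:
    torch.Tensor: The gathered tensor.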

Signature: gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor
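To make the output layout concrete, here is a minimal pure-Python sketch that mimics the gather on a single process, assuming every rank holds an equal-sized shard. The helpers `simulate_gather_split_1d` and `simulate_split_1d` are hypothetical illustrations, not part of InternLM. Plain lists stand in for tensors, and a loop stands in for the NCCL all-gather.

```python
from typing import List, Sequence


def simulate_gather_split_1d(rank_chunks: Sequence[Sequence[float]]) -> List[float]:
    """Hypothetical stand-in that mimics the output layout of gather_split_1d_tensor.

    rank_chunks[i] plays the role of the flattened tensor held by tensor-parallel rank i.
    """
    if not rank_chunks:
        raise ValueError("need at least one rank")
    world_size = len(rank_chunks)
    numel = len(rank_chunks[0])
    # all_gather requires every rank to contribute the same number of elements.
    if any(len(chunk) != numel for chunk in rank_chunks):
        raise ValueError("all ranks must contribute the same number of elements")

    # Preallocate the flat output buffer (the torch.empty call in the original).
    gathered = [0.0] * (world_size * numel)
    for i, chunk in enumerate(rank_chunks):
        # Rank i's data lands in gathered[i * numel : (i + 1) * numel].
        gathered[i * numel : (i + 1) * numel] = chunk
    return gathered


def simulate_split_1d(flat: Sequence[float], world_size: int) -> List[List[float]]:
    """Hypothetical inverse: cut a flat sequence into world_size equal chunks."""
    if len(flat) % world_size != 0:
        raise ValueError("length must be divisible by world_size")
    size = len(flat) // world_size
    return [list(flat[r * size : (r + 1) * size]) for r in range(world_size)]


# Two ranks, each holding a 3-element shard.
shards = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
full = simulate_gather_split_1d(shards)
print(full)  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
assert simulate_split_1d(full, world_size=2) == shards
```

Splitting the gathered result again recovers the per-rank shards, which is the round-trip property implied by describing this function as the inverse of the split.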


```python
import torch
import torch.distributed as dist

# InternLM's parallel context; import paths assumed from the rest of the internlm package.
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc


def gather_split_1d_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Opposite of above function, gather values from model parallel ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be gathered after communication.
    Returns:
        :class:`torch.Tensor`: The gathered tensor.
    """
    world_size = gpc.get_world_size(ParallelMode.TENSOR)
    numel = torch.numel(tensor)
    numel_gathered = world_size * numel
    # One contiguous flat buffer that will hold every rank's data.
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False)
    # Slices are views into `gathered`, so all_gather writes rank i's tensor directly into chunks[i].
    chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
    dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.TENSOR))
    return gathered
```
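The function allocates one flat buffer and hands `all_gather` a list of slices of it, so no concatenation is needed afterwards. The sketch below reproduces that pattern on a single process, assuming CPU tensors instead of `torch.cuda.current_device()`. `fill_gather_buffer` is a hypothetical helper, and its `copy_` loop only stands in for `dist.all_gather`.

```python
import torch


def fill_gather_buffer(shards):
    """Hypothetical single-process stand-in for the all_gather step.

    shards[i] plays the role of the tensor held by tensor-parallel rank i.
    """
    world_size = len(shards)
    numel = shards[0].numel()
    gathered = torch.empty(world_size * numel, dtype=shards[0].dtype)
    # Slicing a 1D tensor returns views that share gathered's storage,
    # so writing into a chunk writes straight into the output buffer.
    chunks = [gathered[i * numel : (i + 1) * numel] for i in range(world_size)]
    for chunk, shard in zip(chunks, shards):
        # dist.all_gather fills chunk i with rank i's input; here we copy locally.
        chunk.copy_(shard.reshape(-1))
    return gathered


shards = [torch.tensor([[1.0, 2.0]]), torch.tensor([[3.0, 4.0]])]
flat = fill_gather_buffer(shards)
print(flat)  # tensor([1., 2., 3., 4.])
# The result is flat; restore a per-rank layout explicitly if needed.
print(flat.view(len(shards), *shards[0].shape).shape)  # torch.Size([2, 1, 2])
```

Because the original also returns this flat buffer, callers that need a multi-dimensional result have to reshape it themselves, for example with `view`.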

Calls (3):
- get_world_size (method)
- empty (method)
- get_group (method)

Tested by: no test coverage detected.