github.com/InternLM/InternLM / _get_tensor_shape

Function _get_tensor_shape

internlm/core/communication/p2p.py:22–42

Get the exact tensor shape to use when communicating, and return whether the tensor is sent as a chunk. When chunk_tensor is True and the tensor's total element count divides evenly by the tensor-parallel world size, the returned "shape" is the per-rank element count (a single int); otherwise the original shape is returned and chunking is switched off.

(tensor_shape: TensorShape, chunk_tensor: bool = False)


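```python
# Module-level imports and alias used by this function
# (defined in p2p.py outside lines 22–42).
import operator
from functools import reduce
from typing import List, Tuple, Union

import torch

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc

TensorShape = Union[torch.Size, List[int], Tuple[int]]


def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
    """get the exact tensor shape when communicating and return whether the tensor is a chunk

    Args:
        tensor_shape (:class:`torch.Size`): shape of tensor
        chunk_tensor (bool, optional): whether to chunk tensor, defaults to False

    Returns:
        Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
    """
    if chunk_tensor:
        # Total number of elements in the full tensor.
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
        tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
        if tensor_chunk_shape % tensor_parallel_world_size == 0:
            # Even split: each tensor-parallel rank handles a flat chunk of this many elements.
            tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
        else:
            # Uneven split: fall back to the full shape and disable chunking.
            tensor_chunk_shape = tensor_shape
            chunk_tensor = False
    else:
        tensor_chunk_shape = tensor_shape
    return tensor_chunk_shape, chunk_tensor
```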
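The arithmetic is easiest to see outside the distributed context. The sketch below is a hypothetical, standalone re-implementation (chunk_shape_sketch is not part of InternLM) that takes the tensor-parallel world size as an argument instead of calling gpc.get_world_size(ParallelMode.TENSOR), and uses plain tuples and lists in place of torch.Size. It follows the same three branches as the original. Note that in the chunked case the first returned value is a flat element count (an int), not a multi-dimensional shape.

```python
import operator
from functools import reduce
from typing import List, Tuple, Union

# Hypothetical stand-in for InternLM's TensorShape alias; torch.Size is
# omitted so this sketch runs without PyTorch.
TensorShape = Union[List[int], Tuple[int, ...]]


def chunk_shape_sketch(
    tensor_shape: TensorShape, tp_world_size: int, chunk_tensor: bool = False
) -> Tuple[Union[int, TensorShape], bool]:
    """Hypothetical mirror of _get_tensor_shape with the tensor-parallel
    world size passed in explicitly instead of read from gpc."""
    if not chunk_tensor:
        # No chunking requested: communicate the full shape unchanged.
        return tensor_shape, False
    # Total number of elements in the tensor.
    numel = reduce(operator.mul, tensor_shape, 1)
    if numel % tp_world_size == 0:
        # Even split: each rank handles numel // tp_world_size elements.
        return numel // tp_world_size, True
    # Uneven split: fall back to the full shape and disable chunking.
    return tensor_shape, False


print(chunk_shape_sketch((4, 1024, 4096), tp_world_size=8, chunk_tensor=True))  # (2097152, True)
print(chunk_shape_sketch((3, 5), tp_world_size=4, chunk_tensor=True))  # ((3, 5), False)
print(chunk_shape_sketch((4, 1024, 4096), tp_world_size=8))  # ((4, 1024, 4096), False)
```

The fallback in the second call matters: 15 elements cannot be split evenly across 4 ranks, so the function reports chunk_tensor as False and callers communicate the whole tensor instead.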

Callers 2

process_object_to_send · Function · 0.85

Calls 1

get_world_size · Method · 0.80

Tested by

no test coverage detected