Get the exact tensor shape used when communicating, and return whether the tensor is chunked. Args: tensor_shape (:class:`torch.Size`): shape of tensor chunk_tensor (bool, optional): whether to chunk tensor, defaults to False Returns: Tuple[Union[:class:`torch.Size`, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk tensor
(tensor_shape: TensorShape, chunk_tensor: bool = False)
| 20 | |
| 21 | |
def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
    """Decide the shape to transfer and whether the transfer is chunked.

    If ``chunk_tensor`` is falsy, both arguments are returned unchanged.

    Otherwise the total element count of ``tensor_shape`` is split evenly
    across the group whose size is ``gpc.get_world_size(ParallelMode.TENSOR)``:

    * If the count divides evenly, the per-rank element count is returned as
      a plain ``int``, together with the original ``chunk_tensor`` value.
    * If it does not divide evenly, chunking is abandoned and
      ``(tensor_shape, False)`` is returned.

    Args:
        tensor_shape (:class:`torch.Size`): shape of tensor
        chunk_tensor (bool, optional): whether to chunk tensor, defaults to False

    Returns:
        Tuple[Union[:class:`torch.Size`, List[int], Tuple[int], int], bool]:
        the shape (or per-rank element count when chunked) to use, and
        whether the tensor is chunked.
    """
    # NOTE(review): the return annotation says TensorShape, but the chunked
    # path returns an int element count. Confirm callers accept that.
    if not chunk_tensor:
        return tensor_shape, chunk_tensor

    # Product of all dimensions; an empty shape yields 1.
    num_elements = reduce(operator.mul, tensor_shape, 1)
    group_size = gpc.get_world_size(ParallelMode.TENSOR)
    per_rank, leftover = divmod(num_elements, group_size)

    if leftover != 0:
        # Cannot split evenly across ranks: fall back to sending the full shape.
        return tensor_shape, False
    return per_rank, chunk_tensor
| 43 | |
| 44 | |
| 45 | def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors): |
no test coverage detected