def recv_obj_meta(prev_rank=None) -> torch.Size:
    """Receives obj meta information before receiving a specific obj.
    Since the recipient must know the shape of the obj in p2p communications,
    meta information of the obj should be received before communications. This function
    synchronizes with :func:`send_obj_meta`.

    Args:
        prev_rank (int, optional): The rank of the source of the obj. If None, the previous
            global rank in the pipeline parallel group is used.

    Returns:
        Union[:class:`torch.Size`, List[:class:`torch.Size`]]: The shape of the obj to be received.
            A single shape is returned when exactly one obj is announced, otherwise a list of shapes.
    """
    if prev_rank is None:
        prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)

    tensor_kwargs = {"dtype": torch.long, "device": get_current_device()}
    # The sender first announces how many objs will follow.
    recv_obj_nums = torch.empty((), **tensor_kwargs)
    dist.recv(recv_obj_nums, prev_rank)
    num_objs = recv_obj_nums.item()
    if num_objs == 1:
        # A single obj: return its shape directly.
        recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
        obj_shape = torch.Size(recv_shape)
    else:
        # Several objs: collect one shape per obj, in the order they were sent.
        obj_shape = []
        for _ in range(num_objs):
            recv_shape = recv_meta_helper(prev_rank, tensor_kwargs)
            obj_shape.append(torch.Size(recv_shape))

    return obj_shape


def split_tensor_into_1d_equal_chunks(tensor: torch.Tensor, new_buffer=False) -> torch.Tensor:
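The count-then-shapes handshake that `recv_obj_meta` performs can be tried without a distributed setup. The sketch below is an illustration only: `FakeChannel`, `send_meta` and `recv_meta` are hypothetical names, not part of the library. It also assumes each shape is sent as its number of dimensions followed by the dimensions themselves; the source does not show what `recv_meta_helper` does internally.

```python
from collections import deque
from typing import List, Tuple, Union

Shape = Tuple[int, ...]


class FakeChannel:
    """Hypothetical in-memory stand-in for a p2p link (not a torch.distributed API)."""

    def __init__(self) -> None:
        self._queue = deque()

    def send(self, value) -> None:
        self._queue.append(value)

    def recv(self):
        return self._queue.popleft()


def send_meta(channel: FakeChannel, shapes: Union[Shape, List[Shape]]) -> None:
    # Normalize a single shape into a one-element list.
    shape_list = [shapes] if isinstance(shapes, tuple) else list(shapes)
    # Step 1: announce how many objs will follow.
    channel.send(len(shape_list))
    # Step 2 (assumed layout): for each obj, send its rank, then its dimensions.
    for shape in shape_list:
        channel.send(len(shape))
        channel.send(tuple(shape))


def recv_meta(channel: FakeChannel) -> Union[Shape, List[Shape]]:
    def recv_one() -> Shape:
        ndims = channel.recv()
        shape = channel.recv()
        assert len(shape) == ndims  # sanity check on the assumed layout
        return shape

    count = channel.recv()
    # Mirror recv_obj_meta: one obj yields a bare shape, anything else yields a list.
    if count == 1:
        return recv_one()
    return [recv_one() for _ in range(count)]


channel = FakeChannel()
send_meta(channel, (4, 8))
single = recv_meta(channel)
send_meta(channel, [(2, 3), (5,)])
several = recv_meta(channel)
```

Here `single` is a bare shape and `several` is a list of shapes, matching the two branches of `recv_obj_meta`. The receiver has to issue its receives in the same order the sender issues its sends, which is why the docstring says the function synchronizes with `send_obj_meta`.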