(recv_shapes, dtype, scatter_gather_tensors)
| 43 | |
| 44 | |
def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
    """Allocate uninitialised receive buffer(s) for the given shape(s).

    Each buffer is created with ``torch.empty`` on the device returned by
    ``get_current_device()``, with ``requires_grad=True`` and the requested
    ``dtype``. The shape actually allocated is the first value returned by
    ``_get_tensor_shape(shape, scatter_gather_tensors)``.

    Args:
        recv_shapes: Either a single ``torch.Size`` or an iterable of shapes.
        dtype: Data type passed to ``torch.empty`` for every buffer.
        scatter_gather_tensors: Forwarded unchanged to ``_get_tensor_shape``.

    Returns:
        A ``(buffer, recv_split)`` tuple. ``buffer`` is a single tensor when
        ``recv_shapes`` is a ``torch.Size``, otherwise a list with one tensor
        per shape. ``recv_split`` is the second value returned by
        ``_get_tensor_shape`` (presumably a flag telling the caller whether the
        tensor was split -- confirm against ``_get_tensor_shape``).

    Note:
        For an iterable input only the ``recv_split`` of the *last* shape is
        returned. If the iterable is empty, ``recv_split`` is never assigned
        and the final ``return`` raises ``UnboundLocalError``.
    """

    def _allocate(shape):
        # Resolve the shape to allocate, then build the empty buffer for it.
        chunk_shape, split = _get_tensor_shape(shape, scatter_gather_tensors)
        tensor = torch.empty(chunk_shape, requires_grad=True, device=get_current_device(), dtype=dtype)
        return tensor, split

    # Single-shape case: hand back one tensor rather than a list.
    if isinstance(recv_shapes, torch.Size):
        return _allocate(recv_shapes)

    buffers = []
    for shape in recv_shapes:
        tensor, recv_split = _allocate(shape)
        buffers.append(tensor)
    return buffers, recv_split
| 56 | |
| 57 | |
| 58 | def process_object_to_send(object_send, scatter_gather_tensors): |
no test coverage detected