Each process scatters a list of input tensors to all processes in a cluster and returns the gathered list of tensors in the output list. The tensors should have the same shape. Parameters ---------- rank : int The rank of the current worker world_size : int The size of the entire cluster
(rank, world_size, output_tensor_list, input_tensor_list)
| 60 | |
| 61 | |
def __alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):
    """
    Each process scatters a list of input tensors to all processes in a cluster
    and gathers the tensors it receives into the output list. The tensors
    should have the same shape.

    Parameters
    ----------
    rank : int
        The rank of the current worker
    world_size : int
        The number of workers in the entire cluster
    output_tensor_list : List of tensor
        The received tensors. Entry ``i`` is overwritten in place with the
        tensor sent by rank ``i``; the tensor objects in the list are kept.
    input_tensor_list : List of tensor
        The tensors to exchange. Entry ``i`` is sent to rank ``i``.
    """
    cpu = torch.device("cpu")

    def _wire_dtype(tensor):
        # TODO(#5002): Boolean data is not supported in
        # ``torch.distributed.scatter()``, so bool tensors travel as int8
        # and are converted back on the receiving side.
        return torch.int8 if tensor.dtype == torch.bool else tensor.dtype

    # The exchange runs on CPU tensors; each outgoing tensor is converted
    # according to its OWN dtype.
    send_list = [
        tensor.to(device=cpu, dtype=_wire_dtype(tensor))
        for tensor in input_tensor_list
    ]

    # Receive directly into the caller's tensor when scatter() can write into
    # it as-is; otherwise stage through a CPU buffer of the wire dtype. The
    # dtype is decided by the output tensor itself, because output[i] comes
    # from rank i and is unrelated to the local input[i].
    recv_list = [
        out
        if out.device.type == "cpu" and _wire_dtype(out) == out.dtype
        else torch.empty(out.shape, dtype=_wire_dtype(out), device=cpu)
        for out in output_tensor_list
    ]

    # Round ``src`` is rank ``src`` handing out its send list; every rank
    # receives its slice into recv_list[src].
    for src in range(world_size):
        dist.scatter(
            recv_list[src], send_list if src == rank else [], src=src
        )

    # Write staged results back in place so the caller's tensor objects hold
    # the data; copy_() also converts int8 -> bool and restores the device.
    for out, buf in zip(output_tensor_list, recv_list):
        if buf is not out:
            out.copy_(buf)
| 98 | |
| 99 | |
| 100 | def alltoallv_cpu(rank, world_size, input_tensor_list, retain_nones=True): |
no test coverage detected