github.com/dmlc/dgl / __alltoall_cpu

Function __alltoall_cpu

tools/distpartitioning/gloo_wrapper.py:62–97

Each process scatters a list of input tensors to all processes in the cluster and receives one tensor from every process into its output list. The tensors should have the same shape.
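The data movement can be checked without a cluster. The pure-Python sketch below models each rank's `input_tensor_list` as a plain list with one item per destination rank; `alltoall_reference` is a hypothetical helper, not part of DGL or PyTorch. Rank `r` ends up with the `r`-th entry of every rank's input, so the whole exchange amounts to transposing a `world_size × world_size` grid.

```python
from typing import List, TypeVar

T = TypeVar("T")


def alltoall_reference(inputs: List[List[T]]) -> List[List[T]]:
    """Return outputs with outputs[dst][src] == inputs[src][dst].

    inputs[src] stands in for input_tensor_list on rank `src`:
    entry `dst` is the item that rank `src` sends to rank `dst`.
    """
    world_size = len(inputs)
    # Every rank must provide exactly one item per destination rank.
    assert all(len(row) == world_size for row in inputs)
    return [[inputs[src][dst] for src in range(world_size)] for dst in range(world_size)]


# Three ranks; the string "a->b" marks the item rank a sends to rank b.
inputs = [[f"{src}->{dst}" for dst in range(3)] for src in range(3)]
outputs = alltoall_reference(inputs)
print(outputs[1])  # ['0->1', '1->1', '2->1']
```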

(rank, world_size, output_tensor_list, input_tensor_list)


def __alltoall_cpu(rank, world_size, output_tensor_list, input_tensor_list):
    """
    Each process scatters a list of input tensors to all processes in the
    cluster and receives one tensor from every process into the output list.
    The tensors should have the same shape.

    Parameters
    ----------
    rank : int
        The rank of the current worker
    world_size : int
        The size of the entire cluster (number of processes)
    output_tensor_list : List of tensor
        The received tensors; entry ``i`` holds the tensor sent by rank ``i``
    input_tensor_list : List of tensor
        The tensors to exchange; entry ``i`` is sent to rank ``i``
    """
    # The exchange runs on CPU tensors, so move every input tensor to the CPU.
    input_tensor_list = [
        tensor.to(torch.device("cpu")) for tensor in input_tensor_list
    ]
    # TODO(#5002): As Boolean data is not supported in
    # ``torch.distributed.scatter()``, we convert boolean into int8 before
    # scatter and convert it back afterwards.
    dtypes = [t.dtype for t in input_tensor_list]
    for i, dtype in enumerate(dtypes):
        if dtype == torch.bool:
            input_tensor_list[i] = input_tensor_list[i].to(torch.int8)
            output_tensor_list[i] = output_tensor_list[i].to(torch.int8)
    # Round i: rank i scatters its whole input list (one tensor per rank),
    # and every rank receives its share into output_tensor_list[i].
    for i in range(world_size):
        dist.scatter(
            output_tensor_list[i], input_tensor_list if i == rank else [], src=i
        )
    # Convert back to original dtype
    for i, dtype in enumerate(dtypes):
        if dtype == torch.bool:
            input_tensor_list[i] = input_tensor_list[i].to(dtype)
            output_tensor_list[i] = output_tensor_list[i].to(dtype)
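The loop in `__alltoall_cpu` performs the exchange in `world_size` rounds: in round `i`, rank `i` is the scatter source and every rank, including `i` itself, receives its slice into `output_tensor_list[i]`. The sketch below replays those rounds on plain Python lists in a single process. `simulate_scatter_rounds` is a hypothetical helper that stands in for the collective `dist.scatter` calls; it is not a DGL or PyTorch API, and it ignores the tensor, device, and dtype handling.

```python
from typing import List, Optional, TypeVar

T = TypeVar("T")


def simulate_scatter_rounds(inputs: List[List[T]]) -> List[List[Optional[T]]]:
    """Replay the `for i in range(world_size): dist.scatter(...)` loop.

    inputs[rank] stands in for input_tensor_list on that rank; the returned
    outputs[rank] stands in for output_tensor_list after the loop finishes.
    """
    world_size = len(inputs)
    # Preallocated receive buffers, like output_tensor_list on each rank.
    outputs: List[List[Optional[T]]] = [[None] * world_size for _ in range(world_size)]
    for src in range(world_size):
        # Only the source rank supplies a scatter list in this round;
        # the other ranks pass [] and only receive.
        scatter_list = inputs[src]
        for rank in range(world_size):
            # Rank `rank` receives scatter_list[rank] into slot `src`.
            outputs[rank][src] = scatter_list[rank]
    return outputs


# Four ranks; value 10 * src + dst is what rank src sends to rank dst.
inputs = [[10 * src + dst for dst in range(4)] for src in range(4)]
outputs = simulate_scatter_rounds(inputs)
print(outputs[2])  # [2, 12, 22, 32]
```

Because every round is a collective, all ranks must run the same `range(world_size)` loop in the same order; a rank that skipped a round would leave the others blocked in `dist.scatter`.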

Callers 1

alltoallv_cpu · Function · 0.85

Calls 2

to · Method · 0.45
device · Method · 0.45

Tested by

no test coverage detected