MCPcopy Index your code
hub / github.com/apache/tvm / test_allreduce

Function test_allreduce

tests/python/disco/test_custom_allreduce.py:57–81  ·  view source on GitHub ↗
(shape, ccl, strategy)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("ccl", _ccl)
@pytest.mark.parametrize("strategy", _strategies)
def test_allreduce(shape, ccl, strategy):
    """Check the CUDA-IPC custom all-reduce across two Disco worker processes.

    Each worker receives a different float32 input of ``shape``, stored in a
    buffer obtained from ``runtime.disco.cuda_ipc.alloc_storage``. After
    ``runtime.disco.cuda_ipc.custom_allreduce`` runs with ``strategy``, both
    workers must hold the element-wise sum of the two inputs.
    """
    device_ids = [0, 1]
    sess: Session = disco.ProcessSession(num_workers=len(device_ids))
    sess.init_ccl(ccl, *device_ids)

    total = reduce(lambda acc, dim: acc * dim, shape)
    dtype = "float32"

    # Remote packed functions used by this test.
    alloc_ipc_storage = sess.get_global_func("runtime.disco.cuda_ipc.alloc_storage")
    alloc_tensor = sess.get_global_func("vm.builtin.alloc_tensor")
    custom_allreduce = sess.get_global_func("runtime.disco.cuda_ipc.custom_allreduce")

    # Allocate the all-reduce input through the cuda_ipc storage allocator,
    # then wrap that storage as a tensor via vm.builtin.alloc_tensor.
    ipc_storage = sess.call_packed(alloc_ipc_storage, Shape(shape), DataType(dtype))
    d_input = sess.call_packed(alloc_tensor, ipc_storage, 0, Shape(shape), DataType(dtype))

    # Worker 0 gets 0, 1, ..., total-1 and worker 1 gets 1, 0, ..., 2-total,
    # so every element of the reduced result is 1.
    per_worker_inputs = (
        np.arange(total, dtype=dtype).reshape(*shape),
        np.arange(start=1, stop=1 - total, step=-1, dtype=dtype).reshape(*shape),
    )
    for rank, host_array in enumerate(per_worker_inputs):
        d_input.debug_copy_from(rank, host_array)
    d_output = sess.empty(shape, dtype)

    sess.call_packed(custom_allreduce, d_input, strategy, d_output)

    # Pull every worker's output back before checking any of them.
    gathered = [d_output.debug_get_from_remote(rank).numpy() for rank in range(len(device_ids))]
    expected = per_worker_inputs[0] + per_worker_inputs[1]
    for actual in gathered:
        np.testing.assert_equal(actual, expected)
82
83
84if __name__ == "__main__":

Callers 1

Calls 13

reduce (function) · 0.85
DataType (class) · 0.85
init_ccl (method) · 0.80
get_global_func (method) · 0.80
call_packed (method) · 0.80
debug_copy_from (method) · 0.80
numpy (method) · 0.80
debug_get_from_remote (method) · 0.80
Shape (class) · 0.50
reshape (method) · 0.45
arange (method) · 0.45
empty (method) · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…