MCPcopy Index your code
hub / github.com/apache/tvm / test_group_allreduce

Function test_group_allreduce

tests/python/disco/test_ccl.py:86–117  ·  view source on GitHub ↗
(session_kind, ccl)

Source from the content-addressed store, hash-verified

@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_group_allreduce(session_kind, ccl):
    """Check in-group allreduce on a four-worker session split into two groups.

    Two session-wide float32 arrays are allocated: one of shape (3, 4), seeded
    on workers 0 and 1, and one of shape (5, 6), seeded on workers 2 and 3.
    For each reduction op ("sum", "prod", "min", "max", "avg") both arrays are
    all-reduced with ``in_group=True``; the value read back from worker 0 must
    equal the NumPy reduction of the two (3, 4) inputs, and the value read back
    from worker 2 must equal the NumPy reduction of the two (5, 6) inputs.

    Parameters
    ----------
    session_kind : callable
        Session factory, invoked as ``session_kind(num_workers=..., num_groups=...)``.
    ccl : str
        Collective-communication backend name forwarded to ``init_ccl``.
    """
    dtype = "float32"
    shape_a, shape_b = (3, 4), (5, 6)
    worker_ids = [0, 1, 2, 3]

    session = session_kind(num_workers=len(worker_ids), num_groups=2)
    session.init_ccl(ccl, *worker_ids)

    # Host-side inputs: an ascending ramp and a descending ramp per shape.
    host_a0 = np.arange(12, dtype=dtype).reshape(shape_a)
    host_a1 = np.arange(1, -11, -1, dtype=dtype).reshape(shape_a)
    host_b0 = np.arange(30, dtype=dtype).reshape(shape_b)
    host_b1 = np.arange(1, -29, -1, dtype=dtype).reshape(shape_b)

    # Each remote array is only populated on the two workers whose reduction
    # is later verified; the other workers' contents are never inspected.
    remote_a = session.empty(shape_a, dtype)
    remote_b = session.empty(shape_b, dtype)
    remote_a.debug_copy_from(0, host_a0)
    remote_a.debug_copy_from(1, host_a1)
    remote_b.debug_copy_from(2, host_b0)
    remote_b.debug_copy_from(3, host_b1)

    # (op name passed to the session, NumPy reference implementation)
    reducers = [
        ("sum", np.add),
        ("prod", np.multiply),
        ("min", np.minimum),
        ("max", np.maximum),
        ("avg", lambda lhs, rhs: (lhs + rhs) * 0.5),
    ]

    for reduce_op, reference_fn in reducers:
        out_a = session.empty(shape_a, dtype)
        out_b = session.empty(shape_b, dtype)
        session.allreduce(remote_a, out_a, op=reduce_op, in_group=True)
        session.allreduce(remote_b, out_b, op=reduce_op, in_group=True)
        got_a = out_a.debug_get_from_remote(0).numpy()
        got_b = out_b.debug_get_from_remote(2).numpy()
        np.testing.assert_equal(got_a, reference_fn(host_a0, host_a1))
        np.testing.assert_equal(got_b, reference_fn(host_b0, host_b1))
118
119
120@pytest.mark.parametrize("session_kind", _all_session_kinds)

Callers

nothing calls this directly

Calls 8

init_ccl — Method · 0.80
debug_copy_from — Method · 0.80
allreduce — Method · 0.80
numpy — Method · 0.80
debug_get_from_remote — Method · 0.80
reshape — Method · 0.45
arange — Method · 0.45
empty — Method · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…