(session_kind, ccl)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_group_allreduce(session_kind, ccl):
    """Check group-scoped allreduce across two independent worker groups.

    Four workers are created with ``num_groups=2``. The expectations below
    assume workers 0 and 1 form one group and workers 2 and 3 the other:
    ``d_array_1`` is seeded only on workers 0/1 and ``d_array_2`` only on
    workers 2/3, and each group's reduction must combine exactly its own two
    inputs. Because this is an *all*reduce, every worker in a group must end
    up holding the reduced value, so the result is verified on each
    participating worker, not just the first one of the group.
    """
    devices = [0, 1, 2, 3]
    sess = session_kind(num_workers=len(devices), num_groups=2)
    sess.init_ccl(ccl, *devices)

    array_1 = np.arange(12, dtype="float32").reshape(3, 4)
    array_2 = np.arange(start=1, stop=-11, step=-1, dtype="float32").reshape(3, 4)
    array_3 = np.arange(30, dtype="float32").reshape(5, 6)
    array_4 = np.arange(start=1, stop=-29, step=-1, dtype="float32").reshape(5, 6)
    d_array_1 = sess.empty((3, 4), "float32")
    d_array_2 = sess.empty((5, 6), "float32")
    # Seed each buffer only on the workers of the group whose result is
    # checked; the other group's copy is not seeded and its output is never
    # inspected.
    d_array_1.debug_copy_from(0, array_1)
    d_array_1.debug_copy_from(1, array_2)
    d_array_2.debug_copy_from(2, array_3)
    d_array_2.debug_copy_from(3, array_4)

    # (worker ids of a group, destination buffer index) pairs to verify.
    group_1_workers = (0, 1)
    group_2_workers = (2, 3)
    for op, np_op in [  # pylint: disable=invalid-name
        ("sum", np.add),
        ("prod", np.multiply),
        ("min", np.minimum),
        ("max", np.maximum),
        ("avg", lambda a, b: (a + b) * 0.5),
    ]:
        dst_array_1 = sess.empty((3, 4), "float32")
        dst_array_2 = sess.empty((5, 6), "float32")
        sess.allreduce(d_array_1, dst_array_1, op=op, in_group=True)
        sess.allreduce(d_array_2, dst_array_2, op=op, in_group=True)
        expected_1 = np_op(array_1, array_2)
        expected_2 = np_op(array_3, array_4)
        # Allreduce must deliver the reduced value to every worker in the
        # group; checking only one rank would not catch a reduce-only bug.
        for worker_id in group_1_workers:
            np.testing.assert_equal(
                dst_array_1.debug_get_from_remote(worker_id).numpy(),
                expected_1,
                err_msg=f"op={op}, worker={worker_id}",
            )
        for worker_id in group_2_workers:
            np.testing.assert_equal(
                dst_array_2.debug_get_from_remote(worker_id).numpy(),
                expected_2,
                err_msg=f"op={op}, worker={worker_id}",
            )
| 118 | |
| 119 | |
| 120 | @pytest.mark.parametrize("session_kind", _all_session_kinds) |
nothing calls this directly
no test coverage detected
searching dependent graphs…