(session_kind, ccl)
@pytest.mark.parametrize("session_kind", _all_session_kinds)
@pytest.mark.parametrize("ccl", _ccl)
def test_allreduce(session_kind, ccl):
    """Check ``sess.allreduce`` for every supported reduction op.

    Two workers each hold a different 3x4 float32 array. After the
    all-reduce, the destination buffer on *every* worker must equal the
    NumPy reference reduction of the two inputs.
    """
    devices = [0, 1]
    sess = session_kind(num_workers=len(devices))
    sess.init_ccl(ccl, *devices)

    shape, dtype = (3, 4), "float32"
    # Worker 0 holds i (0..11) and worker 1 holds 1 - i (1..-10), so every
    # sum/prod/min/max/avg below is an exactly representable float32 value;
    # that is what makes the exact ``assert_equal`` comparison valid.
    array_1 = np.arange(12, dtype=dtype).reshape(shape)
    array_2 = np.arange(start=1, stop=-11, step=-1, dtype=dtype).reshape(shape)
    d_array = sess.empty(shape, dtype)
    d_array.debug_copy_from(0, array_1)
    d_array.debug_copy_from(1, array_2)
    for op, np_op in [  # pylint: disable=invalid-name
        ("sum", np.add),
        ("prod", np.multiply),
        ("min", np.minimum),
        ("max", np.maximum),
        ("avg", lambda a, b: (a + b) * 0.5),
    ]:
        dst_array = sess.empty(shape, dtype)
        sess.allreduce(d_array, dst_array, op=op)
        expected = np_op(array_1, array_2)
        # An all-reduce delivers the result to every participant, so verify
        # each worker's buffer rather than only worker 0's.
        for worker_id in range(len(devices)):
            result = dst_array.debug_get_from_remote(worker_id).numpy()
            np.testing.assert_equal(
                result,
                expected,
                err_msg=f"allreduce op={op!r} mismatch on worker {worker_id}",
            )
| 82 | |
| 83 | |
| 84 | @pytest.mark.parametrize("session_kind", _all_session_kinds) |
nothing calls this directly
no test coverage detected
searching dependent graphs…