def _assert_mparray_matches_ref(ref_mp, comp_mp, k, err_msg=""):
    """Assert that a computed matrix profile matches a naive reference array.

    Parameters
    ----------
    ref_mp : numpy.ndarray
        Naive reference matrix profile. Its columns are sliced as
        ``[P (k cols) | I (k cols) | left_I | right_I]``.
    comp_mp : mparray
        Computed matrix profile exposing ``P_``, ``I_``, ``left_I_`` and
        ``right_I_``.
    k : int
        Number of nearest neighbours stored per subsequence.
    err_msg : str, optional
        Context appended to any assertion failure message.
    """
    # np.squeeze collapses the k == 1 case to 1-D so the shapes line up with
    # the squeezed properties returned by the mparray.
    npt.assert_almost_equal(np.squeeze(ref_mp[:, :k]), comp_mp.P_, err_msg=err_msg)
    npt.assert_almost_equal(
        np.squeeze(ref_mp[:, k : 2 * k]), comp_mp.I_, err_msg=err_msg
    )
    npt.assert_almost_equal(
        np.squeeze(ref_mp[:, 2 * k]), comp_mp.left_I_, err_msg=err_msg
    )
    npt.assert_almost_equal(
        np.squeeze(ref_mp[:, 2 * k + 1]), comp_mp.right_I_, err_msg=err_msg
    )


@pytest.mark.parametrize("T_A, T_B", test_data)
@pytest.mark.parametrize("k", kNN)
def test_mparray_A_B_join(T_A, T_B, k):
    """Check mparray properties of AB-join results for ``stump`` and ``aamp``.

    Each implementation is run on the raw ``T_A``/``T_B`` inputs and on their
    ``pd.Series`` wrappers. Each result is compared against the corresponding
    naive reference, after ``naive.replace_inf`` has been applied to both
    sides.
    """
    m = 3
    implementations = (
        ("stump", naive.stump, stump),
        ("aamp", naive.aamp, aamp),
    )
    for name, naive_func, func in implementations:
        ref_mp = naive_func(T_A, m, T_B=T_B, k=k)
        naive.replace_inf(ref_mp)

        # Fresh pd.Series objects are built per call so no implementation
        # can observe another's inputs.
        inputs = (
            ("raw", T_A, T_B),
            ("pd.Series", pd.Series(T_A), pd.Series(T_B)),
        )
        for input_kind, T_A_in, T_B_in in inputs:
            comp_mp = func(T_A_in, m, T_B_in, ignore_trivial=False, k=k)
            naive.replace_inf(comp_mp)
            _assert_mparray_matches_ref(
                ref_mp,
                comp_mp,
                k,
                err_msg=f"{name} with {input_kind} inputs (k={k})",
            )