| 2140 | |
@pytest.mark.parametrize("use_dask", [True, False])
def test_dot_align_coords(use_dask: bool) -> None:
    """Check that ``xr.dot`` aligns inputs whose coordinates only partly overlap.

    Regression test for GH 3694. It runs with both numpy-backed and
    dask-backed arrays; the dask variant is skipped when dask is missing.
    """
    if use_dask and not has_dask:
        pytest.skip("test for dask.")

    # The "a" and "b" coordinate labels of the two operands overlap only
    # partially, so alignment has real work to do.
    arr_a = xr.DataArray(
        np.arange(30 * 4).reshape(30, 4),
        dims=["a", "b"],
        coords={"a": np.arange(30), "b": np.arange(4)},
    )
    arr_b = xr.DataArray(
        np.arange(30 * 4 * 5).reshape(30, 4, 5),
        dims=["a", "b", "c"],
        coords={"a": np.arange(5, 35), "b": np.arange(1, 5)},
    )
    if use_dask:
        arr_a = arr_a.chunk({"a": 3})
        arr_b = arr_b.chunk({"a": 3})

    # With the default join="inner", `dot` sums over the dimensions the
    # two arguments share ("a" and "b").
    result = xr.dot(arr_a, arr_b)
    reference = (arr_a * arr_b).sum(["a", "b"])
    xr.testing.assert_allclose(reference, result)

    # dim=... contracts over every dimension.
    result = xr.dot(arr_a, arr_b, dim=...)
    reference = (arr_a * arr_b).sum()
    xr.testing.assert_allclose(reference, result)

    # An "exact" join must refuse to align the mismatched coordinates.
    with xr.set_options(arithmetic_join="exact"), pytest.raises(
        ValueError, match=r"cannot align.*join.*exact.*not equal.*"
    ):
        xr.dot(arr_a, arr_b)

    # NOTE: dot always uses join="inner", because `(a * b).sum()` yields the
    # same result for every other join method (all except "exact"). The
    # global arithmetic_join option therefore must not change the outcome.
    for join in ("left", "right", "outer"):
        with xr.set_options(arithmetic_join=join):
            result = xr.dot(arr_a, arr_b)
            reference = (arr_a * arr_b).sum(["a", "b"])
            xr.testing.assert_allclose(reference, result)
| 2192 | |
| 2193 | |
| 2194 | def test_where() -> None: |