@requires_scipy
@pytest.mark.parametrize("use_dask", [True, False])
def test_curvefit(self, use_dask) -> None:
    """Check ``DataArray.curvefit`` on exponential decays, optionally dask-backed.

    The input has three columns along ``x``:

    - ``x=0`` decays with ``n0=3, tau=3``; its first sample is replaced by NaN.
    - ``x=1`` decays with ``n0=5, tau=4``.
    - ``x=2`` is entirely NaN, and its fitted coefficients are expected to be NaN.

    A second fit with ``np.power`` reduces over ``x``. It checks that the
    explicitly named parameter appears in ``param`` and that ``x`` is gone
    from the result's dims.
    """
    if use_dask and not has_dask:
        pytest.skip("requires dask")

    def exponential_decay(t, n0, tau=1):
        # The argument names matter: the expected coefficients below are
        # labelled "n0"/"tau" even though no param_names are passed.
        return n0 * np.exp(-t / tau)

    times = np.arange(0, 5, 0.5)
    columns = [
        exponential_decay(times, 3, 3),
        exponential_decay(times, 5, 4),
        np.nan * times,
    ]
    arr = DataArray(
        np.stack(columns, axis=-1),
        dims=("t", "x"),
        coords={"t": times, "x": [0, 1, 2]},
    )
    # A single missing sample; the x=0 coefficients are still expected to be (3, 3).
    arr[0, 0] = np.nan

    if use_dask:
        # One chunk per x column.
        arr = arr.chunk({"x": 1})

    expected = DataArray(
        [[3, 3], [5, 4], [np.nan, np.nan]],
        dims=("x", "param"),
        coords={"x": [0, 1, 2], "param": ["n0", "tau"]},
    )
    result = arr.curvefit(
        coords=[arr.t],
        func=exponential_decay,
        p0={"n0": 4},
        bounds={"tau": (2, 6)},
    )
    assert_allclose(result.curvefit_coefficients, expected, rtol=1e-3)

    # Fit a single-parameter power law jointly over every x column.
    reduced = arr.compute().curvefit(
        coords="t", func=np.power, reduce_dims="x", param_names=["a"]
    )
    assert "a" in reduced.param
    assert "x" not in reduced.dims
| 4910 | def test_curvefit_helpers(self) -> None: |
| 4911 | def exp_decay(t, n0, tau=1): |