(self)
| 2693 | assert tree.chunksizes == original_chunksizes, "original tree was modified" |
| 2694 | |
def test_persist(self):
    """Check that ``DataTree.persist`` flattens each node's dask graph.

    Every node of ``tree`` is built as ``fn(ds.chunk(...))`` so its dask
    graph has two high-level layers. After ``persist()`` the returned tree
    must be identical to the expected result, keep the same chunk sizes and
    have a single-layer graph per node, while the original tree must keep
    both its chunk sizes and its two-layer graphs.
    """
    ds1 = xr.Dataset({"a": ("x", np.arange(10))})
    ds2 = xr.Dataset({"b": ("y", np.arange(5))})
    ds3 = xr.Dataset({"c": ("z", np.arange(4))})
    ds4 = xr.Dataset({"d": ("x", np.arange(-5, 5))})

    def fn(x):
        return 2 * x

    def hlg_depths(dt):
        # Number of high-level graph layers of each node, keyed by node path.
        return {
            node.path: len(node.dataset.__dask_graph__().layers)
            for node in dt.subtree
        }

    expected = xr.DataTree.from_dict(
        {
            "/": fn(ds1).chunk({"x": 5}),
            "/group1": fn(ds2).chunk({"y": 3}),
            "/group2": fn(ds3).chunk({"z": 2}),
            "/group1/subgroup1": fn(ds4).chunk({"x": 5}),
        }
    )
    # Add trivial second layer to the task graph, persist should reduce to one
    tree = xr.DataTree.from_dict(
        {
            "/": fn(ds1.chunk({"x": 5})),
            "/group1": fn(ds2.chunk({"y": 3})),
            "/group2": fn(ds3.chunk({"z": 2})),
            "/group1/subgroup1": fn(ds4.chunk({"x": 5})),
        }
    )
    original_chunksizes = tree.chunksizes
    original_hlg_depths = hlg_depths(tree)
    # Precondition: the setup above must really produce two-layer graphs,
    # otherwise the "reduced to one layer" check below proves nothing.
    assert all(d == 2 for d in original_hlg_depths.values()), (
        "unexpected dask graph depth before persist"
    )

    actual = tree.persist()
    actual_hlg_depths = hlg_depths(actual)

    assert_identical(actual, expected)

    assert actual.chunksizes == original_chunksizes, "chunksizes were modified"
    assert tree.chunksizes == original_chunksizes, (
        "original chunksizes were modified"
    )
    assert all(d == 1 for d in actual_hlg_depths.values()), (
        "unexpected dask graph depth"
    )
    # Re-measure the original tree *after* persist(): comparing only the
    # pre-persist snapshot could never detect an in-place modification.
    assert hlg_depths(tree) == original_hlg_depths, (
        "original dask graph was modified"
    )
| 2745 | |
| 2746 | def test_chunk(self): |
| 2747 | ds1 = xr.Dataset({"a": ("x", np.arange(10))}) |
nothing calls this directly
no test coverage detected