(expr)
| 21 | |
| 22 | |
def create_array_collection(expr):
    """Build a dask array collection from an expression whose meta is array-like.

    This is hacky and an abstraction leak, but utilizing get_collection_type
    to infer that we want to create an array is the only way that is
    guaranteed to be a general solution. We can get rid of this when we have
    an Array expression.

    Parameters
    ----------
    expr
        The expression to wrap. If the array-expr backend is enabled and
        ``expr`` is already an ``ArrayExpr``, it is wrapped directly.
        Otherwise it is optimized and its graph is re-keyed as array blocks.

    Returns
    -------
    Array
        With the array-expr backend enabled, the result of ``from_graph``
        (or ``Array(expr)`` for an ``ArrayExpr`` input). Otherwise a legacy
        ``dask.array.Array``.
    """
    import dask.array as da
    from dask.highlevelgraph import HighLevelGraph
    from dask.layers import Blockwise

    if da._array_expr_enabled():
        from dask.array._array_expr._expr import ArrayExpr

        if isinstance(expr, ArrayExpr):
            from dask.array._array_expr._collection import Array

            return Array(expr)

    result = expr.optimize()
    dsk = result.__dask_graph__()
    name = result._name
    meta = result._meta
    divisions = result.divisions
    npartitions = len(divisions) - 1

    # Leading-axis chunk sizes are unknown (NaN), one per partition.
    # Every trailing dimension is a single chunk spanning its full length.
    chunks = ((np.nan,) * npartitions,) + tuple((d,) for d in meta.shape[1:])

    # Partition i becomes block (i, 0, ..., 0). The key list is computed once
    # for all branches, so from_graph always receives one key per block. That
    # holds for the 1-D case and for Blockwise layers, not only the N-D
    # task-by-task branch.
    # NOTE(review): assumes from_graph expects the full set of block keys;
    # confirm against its implementation.
    suffix = (0,) * (len(chunks) - 1)
    new_keys = [(name, i) + suffix for i in range(npartitions)]

    if len(chunks) > 1:
        if isinstance(dsk, HighLevelGraph):
            layer = dsk.layers[name]
        else:
            # dask-expr provides a dict only
            layer = dsk

        if isinstance(layer, Blockwise):
            # Add one extra output index ("j") covered by a single chunk,
            # so Blockwise emits keys of the form (name, i, 0).
            layer.new_axes["j"] = chunks[1][0]
            layer.output_indices = layer.output_indices + ("j",)
        else:
            from dask._task_spec import Alias, Task

            # Expose each partition's output under its N-D block key.
            for i, new_key in enumerate(new_keys):
                task = layer.get((name, i))
                if isinstance(task, Task):
                    # Point the new key at the existing task rather than
                    # duplicating its computation.
                    task = Alias(new_key, task.key)
                layer[new_key] = task

    if da._array_expr_enabled():
        from dask.array._array_expr._collection import from_graph

        return from_graph(dsk, meta, chunks, set(new_keys), name)
    return da.Array(dsk, name=name, chunks=chunks, dtype=meta.dtype)
| 77 | |
| 78 | |
| 79 | @get_collection_type.register(np.ndarray) |
nothing calls this directly
no test coverage detected
searching dependent graphs…