Slice an N-D dask array with a 1-D dask array of ints along the given axis. This is a helper function of :func:`slice_with_int_dask_array`.
(x, idx, axis)
| 990 | |
| 991 | |
| 992 | def slice_with_int_dask_array_on_axis(x, idx, axis): |
| 993 | """Slice a ND dask array with a 1D dask arrays of ints along the given |
| 994 | axis. |
| 995 | |
| 996 | This is a helper function of :func:`slice_with_int_dask_array`. |
| 997 | """ |
| 998 | from dask.array import chunk |
| 999 | from dask.array.core import Array, blockwise, from_array |
| 1000 | from dask.array.utils import asarray_safe |
| 1001 | |
| 1002 | assert 0 <= axis < x.ndim |
| 1003 | |
| 1004 | if np.isnan(x.chunks[axis]).any(): |
| 1005 | raise NotImplementedError( |
| 1006 | "Slicing an array with unknown chunks with " |
| 1007 | "a dask.array of ints is not supported" |
| 1008 | ) |
| 1009 | |
| 1010 | # Calculate the offset at which each chunk starts along axis |
| 1011 | # e.g. chunks=(..., (5, 3, 4), ...) -> offset=[0, 5, 8] |
| 1012 | offset = np.roll(np.cumsum(asarray_safe(x.chunks[axis], like=x._meta)), 1) |
| 1013 | offset[0] = 0 |
| 1014 | offset = from_array(offset, chunks=1) |
| 1015 | # Tamper with the declared chunks of offset to make blockwise align it with |
| 1016 | # x[axis] |
| 1017 | offset = Array( |
| 1018 | offset.dask, offset.name, (x.chunks[axis],), offset.dtype, meta=x._meta |
| 1019 | ) |
| 1020 | |
| 1021 | # Define axis labels for blockwise |
| 1022 | x_axes = tuple(range(x.ndim)) |
| 1023 | idx_axes = (x.ndim,) # arbitrary index not already in x_axes |
| 1024 | offset_axes = (axis,) |
| 1025 | p_axes = x_axes[: axis + 1] + idx_axes + x_axes[axis + 1 :] |
| 1026 | y_axes = x_axes[:axis] + idx_axes + x_axes[axis + 1 :] |
| 1027 | |
| 1028 | # Calculate the cartesian product of every chunk of x vs every chunk of idx |
| 1029 | p = blockwise( |
| 1030 | chunk.slice_with_int_dask_array, |
| 1031 | p_axes, |
| 1032 | x, |
| 1033 | x_axes, |
| 1034 | idx, |
| 1035 | idx_axes, |
| 1036 | offset, |
| 1037 | offset_axes, |
| 1038 | x_size=x.shape[axis], |
| 1039 | axis=axis, |
| 1040 | dtype=x.dtype, |
| 1041 | meta=x._meta, |
| 1042 | ) |
| 1043 | |
| 1044 | # Aggregate on the chunks of x along axis |
| 1045 | y = blockwise( |
| 1046 | chunk.slice_with_int_dask_array_aggregate, |
| 1047 | y_axes, |
| 1048 | idx, |
| 1049 | idx_axes, |
no test coverage detected
searching dependent graphs…