Dask-aware bottleneck.push
(array, n, axis, method="blelloch")
| 91 | |
| 92 | |
def push(array, n, axis, method="blelloch"):
    """
    Dask-aware bottleneck.push

    Parameters
    ----------
    array : dask array
        Input array; its ``chunks``, ``shape``, ``dtype`` and ``map_overlap``
        are used here.
    n : int or None
        Limit applied to the push. ``None`` means no limit is applied.
    axis : int
        Axis along which the push is performed.
    method : str, default: "blelloch"
        Forwarded to ``da.reductions.cumreduction`` and to ``cumsum``.

    Returns
    -------
    dask array
        The pushed array, built lazily.
    """
    import dask.array as da
    import numpy as np

    from xarray.core.duck_array_ops import _push
    from xarray.core.nputils import nanlast

    has_limit = n is not None

    # Fast path: when the limit fits inside every chunk along ``axis``, each
    # block only needs ``n`` trailing elements of the block before it.
    if has_limit and all(n <= chunk_len for chunk_len in array.chunks[axis]):
        return array.map_overlap(_push, depth={axis: (n, 0)}, n=n, axis=axis)

    # TODO: Replace this whole function once
    # https://github.com/pydata/xarray/issues/9229 is implemented

    filled = da.reductions.cumreduction(
        func=_dtype_push,
        binop=_fill_with_last_one,
        ident=np.nan,
        x=array,
        axis=axis,
        dtype=array.dtype,
        method=method,
        preop=nanlast,
    )

    # The limit mask below is only built for 0 < n < (axis length - 1);
    # in every other case the unlimited result is returned as is.
    if not has_limit or not (0 < n < array.shape[axis] - 1):
        return filled

    # Count, for every position, how many NaNs have been seen in a row: take a
    # running total of the NaN mask and subtract the total recorded at the
    # most recent non-NaN position, so the count restarts after each valid
    # value.
    is_nan = da.isnan(array, dtype=int)
    running_nan_total = is_nan.cumsum(axis=axis, method=method)
    total_at_valid = da.where(is_nan == 0, running_nan_total, np.nan)
    total_at_last_valid = push(total_at_valid, None, axis, method=method)
    # Positions before the first non-NaN value have nothing to carry forward;
    # their NaN becomes 0.
    total_at_last_valid = da.nan_to_num(total_at_last_valid)
    consecutive_nans = running_nan_total - total_at_last_valid

    # Keep the pushed value only where the NaN run does not exceed the limit.
    return da.where(consecutive_nans <= n, filled, np.nan)