Pointwise indexing with only NumPy arrays.
(x, dict_indexes)
| 5852 | |
| 5853 | |
| 5854 | def _vindex_array(x, dict_indexes): |
| 5855 | """Point wise indexing with only NumPy Arrays.""" |
| 5856 | |
| 5857 | token = tokenize(x, dict_indexes) |
| 5858 | try: |
| 5859 | broadcast_shape = np.broadcast_shapes( |
| 5860 | *(arr.shape for arr in dict_indexes.values()) |
| 5861 | ) |
| 5862 | except ValueError as e: |
| 5863 | # note: error message exactly matches numpy |
| 5864 | shapes_str = " ".join(str(a.shape) for a in dict_indexes.values()) |
| 5865 | raise IndexError( |
| 5866 | "shape mismatch: indexing arrays could not be " |
| 5867 | f"broadcast together with shapes {shapes_str}" |
| 5868 | ) from e |
| 5869 | npoints = math.prod(broadcast_shape) |
| 5870 | axes = [i for i in range(x.ndim) if i in dict_indexes] |
| 5871 | |
| 5872 | def _subset_to_indexed_axes(iterable): |
| 5873 | for i, elem in enumerate(iterable): |
| 5874 | if i in axes: |
| 5875 | yield elem |
| 5876 | |
| 5877 | bounds2 = tuple( |
| 5878 | np.array(cached_cumsum(c, initial_zero=True)) |
| 5879 | for c in _subset_to_indexed_axes(x.chunks) |
| 5880 | ) |
| 5881 | axis = _get_axis(tuple(i if i in axes else None for i in range(x.ndim))) |
| 5882 | out_name = f"vindex-merge-{token}" |
| 5883 | |
| 5884 | # Now compute indices of each output element within each input block |
| 5885 | # The index is relative to the block, not the array. |
| 5886 | block_idxs = tuple( |
| 5887 | np.searchsorted(b, ind, side="right") - 1 |
| 5888 | for b, ind in zip(bounds2, dict_indexes.values()) |
| 5889 | ) |
| 5890 | starts = (b[i] for i, b in zip(block_idxs, bounds2)) |
| 5891 | inblock_idxs = [] |
| 5892 | for idx, start in zip(dict_indexes.values(), starts): |
| 5893 | a = idx - start |
| 5894 | if len(a) > 0: |
| 5895 | dtype = np.min_scalar_type(np.max(a, axis=None)) |
| 5896 | inblock_idxs.append(a.astype(dtype, copy=False)) |
| 5897 | else: |
| 5898 | inblock_idxs.append(a) |
| 5899 | |
| 5900 | inblock_idxs = np.broadcast_arrays(*inblock_idxs) |
| 5901 | |
| 5902 | chunks = [c for i, c in enumerate(x.chunks) if i not in axes] |
| 5903 | # determine number of points in one single output block. |
| 5904 | # Use the input chunk size to determine this. |
| 5905 | max_chunk_point_dimensions = reduce( |
| 5906 | mul, map(cached_max, _subset_to_indexed_axes(x.chunks)) |
| 5907 | ) |
| 5908 | |
| 5909 | n_chunks, remainder = divmod(npoints, max_chunk_point_dimensions) |
| 5910 | chunks.insert( |
| 5911 | 0, |
no test coverage detected
searching dependent graphs…