Final aggregation function of `slice_with_int_dask_array_on_axis`. Aggregate all chunks of x by one chunk of idx, reordering the output of `slice_with_int_dask_array`. Note that there is no combine function, as a recursive aggregation (e.g. with `split_every`) would not give any benefit.
slice_with_int_dask_array_aggregate(idx, chunk_outputs, x_chunks, axis)
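As a rough illustration of the reordering described above, the following pure-Python sketch models the one-dimensional case with plain lists. The helpers `per_chunk_select` and `final_positions` are hypothetical names introduced only for this example, and `per_chunk_select` encodes an assumption about the per-chunk kernel: that each chunk of x returns the elements whose indices fall inside it, in the order they appear in idx.

```python
def per_chunk_select(x, idx, x_chunks):
    """Hypothetical stand-in for the per-chunk kernel on 1-D lists.

    Assumption: outputs are grouped by chunk of x and, within each
    chunk, keep the order in which the indices appear in idx.
    """
    n = sum(x_chunks)
    idx = [i + n if i < 0 else i for i in idx]  # normalize negative indices
    out = []
    start = 0
    for size in x_chunks:
        # Keep the indices that land in this chunk of x, in idx order
        out.extend(x[i] for i in idx if start <= i < start + size)
        start += size
    return out


def final_positions(idx, x_chunks):
    """Position of each idx element inside the concatenated chunk outputs."""
    n = sum(x_chunks)
    idx = [i + n if i < 0 else i for i in idx]
    positions = [0] * len(idx)
    start = 0
    offset = 0  # number of outputs produced by earlier chunks of x
    for size in x_chunks:
        seen = 0  # running count of hits within this chunk
        for k, i in enumerate(idx):
            if start <= i < start + size:
                positions[k] = offset + seen
                seen += 1
        start += size
        offset += seen
    return positions


x = [10, 11, 12, 13, 14, 15, 16, 17]
x_chunks = (3, 3, 2)
idx = [7, 0, 4, 1, -1]

chunk_outputs = per_chunk_select(x, idx, x_chunks)
positions = final_positions(idx, x_chunks)
result = [chunk_outputs[p] for p in positions]

print(chunk_outputs)  # [10, 11, 14, 17, 17]
print(positions)      # [3, 0, 2, 1, 4]
print(result)         # [17, 10, 14, 11, 17]
```

The concatenated outputs are grouped by chunk of x, so they are not in idx order; indexing them with `positions` restores the order a direct `x[idx]` would give.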
```python
import numpy as np


def slice_with_int_dask_array_aggregate(idx, chunk_outputs, x_chunks, axis):
    """Final aggregation function of `slice_with_int_dask_array_on_axis`.
    Aggregate all chunks of x by one chunk of idx, reordering the output of
    `slice_with_int_dask_array`.

    Note that there is no combine function, as a recursive aggregation (e.g.
    with split_every) would not give any benefit.

    Parameters
    ----------
    idx: ndarray, ndim=1, dtype=any integer
        j-th chunk of idx
    chunk_outputs: ndarray
        concatenation along axis of the outputs of `slice_with_int_dask_array`
        for all chunks of x and the j-th chunk of idx
    x_chunks: tuple
        dask chunks of the x da.Array along axis, e.g. ``(3, 3, 2)``
    axis: int
        normalized axis to take elements from (0 <= axis < x.ndim)

    Returns
    -------
    Selection from all chunks of x for the j-th chunk of idx, in the correct
    order
    """
    # Needed when idx is unsigned
    idx = idx.astype(np.int64)

    # Normalize negative indices
    idx = np.where(idx < 0, idx + sum(x_chunks), idx)

    x_chunk_offset = 0
    chunk_output_offset = 0

    # Assemble the final index that picks from the output of the previous
    # kernel by adding together one layer per chunk of x
    # FIXME: this could probably be reimplemented with a faster search-based
    # algorithm
    idx_final = np.zeros_like(idx)
    for x_chunk in x_chunks:
        idx_filter = (idx >= x_chunk_offset) & (idx < x_chunk_offset + x_chunk)
        idx_cum = np.cumsum(idx_filter)
        idx_final += np.where(idx_filter, idx_cum - 1 + chunk_output_offset, 0)
        x_chunk_offset += x_chunk
        if idx_cum.size > 0:
            chunk_output_offset += idx_cum[-1]

    # np.take does not support slice indices
    # return np.take(chunk_outputs, idx_final, axis)
    return chunk_outputs[
        tuple(
            idx_final if i == axis else slice(None) for i in range(chunk_outputs.ndim)
        )
    ]
```
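The FIXME in the loop above mentions a possible search-based algorithm. One way such a variant might look is sketched below; this is not the dask implementation, and `final_index_by_search` is a hypothetical name used only here. It assumes, like the loop, that every index is in bounds after normalization, and it relies on `np.searchsorted` to find the chunk of x for each index and on a stable `np.argsort` to recover the order of the concatenated chunk outputs.

```python
import numpy as np


def final_index_by_search(idx, x_chunks):
    """Hypothetical search-based variant; not the dask implementation."""
    idx = np.asarray(idx).astype(np.int64)
    idx = np.where(idx < 0, idx + sum(x_chunks), idx)

    # Right edge of each chunk of x, e.g. (3, 3, 2) -> [3, 6, 8]
    bounds = np.cumsum(x_chunks)
    # Chunk of x that each index falls into
    chunk_id = np.searchsorted(bounds, idx, side="right")

    # A stable sort by chunk id lists the idx positions in the order in
    # which their values appear in the concatenated chunk outputs
    order = np.argsort(chunk_id, kind="stable")

    # Invert that permutation to get, per idx position, its output slot
    idx_final = np.empty_like(order)
    idx_final[order] = np.arange(order.size)
    return idx_final


x_chunks = (3, 3, 2)
idx = np.array([7, 0, 4, 1, -1])
idx_final = final_index_by_search(idx, x_chunks)
print(idx_final.tolist())  # [3, 0, 2, 1, 4]
```

For this input the result matches what the loop computes layer by layer. The sketch replaces one pass per chunk of x with a single search and sort; whether that is actually faster depends on the number of chunks and the size of idx, and has not been measured here.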