Compute the rechunk of *x* to the given *chunks*.
(x, chunks)
| 683 | |
| 684 | |
| 685 | def _compute_rechunk(x, chunks): |
| 686 | """Compute the rechunk of *x* to the given *chunks*.""" |
| 687 | if x.size == 0: |
| 688 | # Special case for empty array, as the algorithm below does not behave correctly |
| 689 | return empty(x.shape, chunks=chunks, dtype=x.dtype) |
| 690 | |
| 691 | ndim = x.ndim |
| 692 | crossed = intersect_chunks(x.chunks, chunks) |
| 693 | x2 = dict() |
| 694 | intermediates = dict() |
| 695 | token = tokenize(x, chunks) |
| 696 | merge_name = f"rechunk-merge-{token}" |
| 697 | split_name = f"rechunk-split-{token}" |
| 698 | split_name_suffixes = count() |
| 699 | |
| 700 | # Pre-allocate old block references, to allow reuse and reduce the |
| 701 | # graph's memory footprint a bit. |
| 702 | old_blocks = np.empty([len(c) for c in x.chunks], dtype="O") |
| 703 | for index in np.ndindex(old_blocks.shape): |
| 704 | old_blocks[index] = (x.name,) + index |
| 705 | |
| 706 | # Iterate over all new blocks |
| 707 | new_index = product(*(range(len(c)) for c in chunks)) |
| 708 | |
| 709 | for new_idx, cross1 in zip(new_index, crossed): |
| 710 | key = (merge_name,) + new_idx |
| 711 | old_block_indices = [[cr[i][0] for cr in cross1] for i in range(ndim)] |
| 712 | subdims1 = [len(set(old_block_indices[i])) for i in range(ndim)] |
| 713 | |
| 714 | rec_cat_arg = np.empty(subdims1, dtype="O") |
| 715 | rec_cat_arg_flat = rec_cat_arg.flat |
| 716 | |
| 717 | # Iterate over the old blocks required to build the new block |
| 718 | for rec_cat_index, ind_slices in enumerate(cross1): |
| 719 | old_block_index, slices = zip(*ind_slices) |
| 720 | name = (split_name, next(split_name_suffixes)) |
| 721 | old_index = old_blocks[old_block_index][1:] |
| 722 | if all( |
| 723 | slc.start == 0 and slc.stop == x.chunks[i][ind] |
| 724 | for i, (slc, ind) in enumerate(zip(slices, old_index)) |
| 725 | ): |
| 726 | rec_cat_arg_flat[rec_cat_index] = TaskRef(old_blocks[old_block_index]) |
| 727 | else: |
| 728 | intermediates[name] = Task( |
| 729 | name, getitem, TaskRef(old_blocks[old_block_index]), slices |
| 730 | ) |
| 731 | rec_cat_arg_flat[rec_cat_index] = TaskRef(name) |
| 732 | |
| 733 | assert rec_cat_index == rec_cat_arg.size - 1 |
| 734 | |
| 735 | # New block is formed by concatenation of sliced old blocks |
| 736 | if all(d == 1 for d in rec_cat_arg.shape): |
| 737 | x2[key] = Alias(key, rec_cat_arg.flat[0]) |
| 738 | else: |
| 739 | x2[key] = Task( |
| 740 | key, |
| 741 | concatenate_shaped, |
| 742 | parse_input(list(rec_cat_arg.flatten())), |
no test coverage detected
searching dependent graphs…