Compute the rechunk of the array named *old_name*, currently chunked as *old_chunks*, to the given *chunks*.
(old_name, old_chunks, chunks, level, name)
| 133 | |
| 134 | |
| 135 | def _compute_rechunk(old_name, old_chunks, chunks, level, name): |
| 136 | """Compute the rechunk of *x* to the given *chunks*.""" |
| 137 | # TODO: redo this logic |
| 138 | # if x.size == 0: |
| 139 | # # Special case for empty array, as the algorithm below does not behave correctly |
| 140 | # return empty(x.shape, chunks=chunks, dtype=x.dtype) |
| 141 | |
| 142 | ndim = len(old_chunks) |
| 143 | crossed = intersect_chunks(old_chunks, chunks) |
| 144 | x2 = dict() |
| 145 | intermediates = dict() |
| 146 | # token = tokenize(old_name, chunks) |
| 147 | if level != 0: |
| 148 | merge_name = name.replace("rechunk-merge-", f"rechunk-merge-{level}-") |
| 149 | split_name = name.replace("rechunk-merge-", f"rechunk-split-{level}-") |
| 150 | else: |
| 151 | merge_name = name.replace("rechunk-merge-", "rechunk-merge-") |
| 152 | split_name = name.replace("rechunk-merge-", "rechunk-split-") |
| 153 | split_name_suffixes = itertools.count() |
| 154 | |
| 155 | # Pre-allocate old block references, to allow reuse and reduce the |
| 156 | # graph's memory footprint a bit. |
| 157 | old_blocks = np.empty([len(c) for c in old_chunks], dtype="O") |
| 158 | for index in np.ndindex(old_blocks.shape): |
| 159 | old_blocks[index] = (old_name,) + index |
| 160 | |
| 161 | # Iterate over all new blocks |
| 162 | new_index = itertools.product(*(range(len(c)) for c in chunks)) |
| 163 | |
| 164 | for new_idx, cross1 in zip(new_index, crossed): |
| 165 | key = (merge_name,) + new_idx |
| 166 | old_block_indices = [[cr[i][0] for cr in cross1] for i in range(ndim)] |
| 167 | subdims1 = [len(set(old_block_indices[i])) for i in range(ndim)] |
| 168 | |
| 169 | rec_cat_arg = np.empty(subdims1, dtype="O") |
| 170 | rec_cat_arg_flat = rec_cat_arg.flat |
| 171 | |
| 172 | # Iterate over the old blocks required to build the new block |
| 173 | for rec_cat_index, ind_slices in enumerate(cross1): |
| 174 | old_block_index, slices = zip(*ind_slices) |
| 175 | name = (split_name, next(split_name_suffixes)) |
| 176 | old_index = old_blocks[old_block_index][1:] |
| 177 | if all( |
| 178 | slc.start == 0 and slc.stop == old_chunks[i][ind] |
| 179 | for i, (slc, ind) in enumerate(zip(slices, old_index)) |
| 180 | ): |
| 181 | rec_cat_arg_flat[rec_cat_index] = old_blocks[old_block_index] |
| 182 | else: |
| 183 | intermediates[name] = ( |
| 184 | operator.getitem, |
| 185 | old_blocks[old_block_index], |
| 186 | slices, |
| 187 | ) |
| 188 | rec_cat_arg_flat[rec_cat_index] = name |
| 189 | |
| 190 | assert rec_cat_index == rec_cat_arg.size - 1 |
| 191 | |
| 192 | # New block is formed by concatenation of sliced old blocks |