def common_blockdim(blockdims):
    """Find the common block dimensions from the list of block dimensions

    Currently only implements the simplest possible heuristic: the common
    block-dimension is the only one that does not fully span a dimension.
    This is a conservative choice that allows us to avoid potentially very
    expensive rechunking.

    Assumes that all elements of the input block dimensions have the same
    sum (i.e., that they correspond to dimensions of the same size).

    Examples
    --------
    >>> common_blockdim([(3,), (2, 1)])
    (2, 1)
    >>> common_blockdim([(1, 2), (2, 1)])
    (1, 1, 1)
    >>> common_blockdim([(2, 2), (3, 1)])  # doctest: +SKIP
    Traceback (most recent call last):
        ...
    ValueError: Chunks do not align
    """
    if not any(blockdims):
        return ()
    non_trivial_dims = {d for d in blockdims if len(d) > 1}
    if len(non_trivial_dims) == 1:
        return first(non_trivial_dims)
    if len(non_trivial_dims) == 0:
        return max(blockdims, key=first)

    if np.isnan(sum(map(sum, blockdims))):
        raise ValueError(
            f"Arrays' chunk sizes ({blockdims}) are unknown.\n\n"
            "A possible solution:\n"
            "  x.compute_chunk_sizes()"
        )

    if len(set(map(sum, non_trivial_dims))) > 1:
        raise ValueError("Chunks do not add up to same value", blockdims)

    # We have multiple non-trivial chunks on this axis
    # e.g. (5, 2) and (4, 3)

    # We create a single chunk tuple with the same total length
    # that evenly divides both, e.g. (4, 1, 2)

    # To accomplish this we walk down all chunk tuples together, finding the
    # smallest leading element, adding it to the output, subtracting it from
    # the leading element of every tuple, and removing any element that
    # reaches zero. We stop once we have burned through all of the chunk
    # tuples.
    # For efficiency's sake we reverse the lists so that we can pop off the end
    rchunks = [list(ntd)[::-1] for ntd in non_trivial_dims]
    total = sum(first(non_trivial_dims))
    i = 0

    out = []
    while i < total:
        m = min(c[-1] for c in rchunks)
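The listing stops partway through the loop. As an illustration of the strategy the comments describe (repeatedly take the smallest leading chunk, emit it, and subtract it from every tuple), here is a minimal standalone sketch. The helper name `align_chunks` and its body are assumptions made for illustration, not the library's actual code, and it assumes known, non-negative integer chunk sizes.

```python
def align_chunks(chunk_tuples):
    """Merge several chunkings of one axis into a single chunking whose
    boundaries include every boundary of each input.

    Hypothetical helper for illustration only; not the library's code.
    """
    totals = {sum(c) for c in chunk_tuples}
    if len(totals) != 1:
        raise ValueError("Chunks do not add up to same value", chunk_tuples)
    total = totals.pop()

    # Reverse each tuple so the current chunk sits at the end of its list
    # and can be dropped with a cheap pop().
    remaining = [list(c)[::-1] for c in chunk_tuples]

    out = []
    covered = 0
    while covered < total:
        # The smallest current chunk marks the next shared boundary.
        step = min(r[-1] for r in remaining)
        out.append(step)
        for r in remaining:
            r[-1] -= step
            if r[-1] == 0:
                r.pop()  # this chunk is fully consumed
        covered += step
    return tuple(out)


print(align_chunks([(5, 2), (4, 3)]))  # (4, 1, 2)
print(align_chunks([(1, 2), (2, 1)]))  # (1, 1, 1)
```

The two calls reproduce the `(5, 2)` / `(4, 3)` example from the comments and the second docstring example. Each output chunk ends at a boundary of at least one input, so every input chunk is an exact union of output chunks.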