(inshape, outshape, inchunks, disallow_dimension_expansion=False)
| 31 | |
| 32 | |
| 33 | def reshape_rechunk(inshape, outshape, inchunks, disallow_dimension_expansion=False): |
| 34 | assert all(isinstance(c, tuple) for c in inchunks) |
| 35 | ii = len(inshape) - 1 |
| 36 | oi = len(outshape) - 1 |
| 37 | result_inchunks = [None for i in range(len(inshape))] |
| 38 | result_outchunks = [None for i in range(len(outshape))] |
| 39 | mapper_in, one_dimensions = {}, [] |
| 40 | |
| 41 | while ii >= 0 or oi >= 0: |
| 42 | if inshape[ii] == outshape[oi]: |
| 43 | result_inchunks[ii] = inchunks[ii] |
| 44 | result_outchunks[oi] = inchunks[ii] |
| 45 | mapper_in[ii] = oi |
| 46 | ii -= 1 |
| 47 | oi -= 1 |
| 48 | continue |
| 49 | din = inshape[ii] |
| 50 | dout = outshape[oi] |
| 51 | if din == 1: |
| 52 | result_inchunks[ii] = (1,) |
| 53 | ii -= 1 |
| 54 | elif dout == 1: |
| 55 | result_outchunks[oi] = (1,) |
| 56 | one_dimensions.append(oi) |
| 57 | oi -= 1 |
| 58 | elif din < dout: # (4, 4, 4) -> (64,) |
| 59 | ileft = ii - 1 |
| 60 | mapper_in[ii] = oi |
| 61 | while ( |
| 62 | ileft >= 0 and reduce(mul, inshape[ileft : ii + 1]) < dout |
| 63 | ): # 4 < 64, 4*4 < 64, 4*4*4 == 64 |
| 64 | mapper_in[ileft] = oi |
| 65 | ileft -= 1 |
| 66 | |
| 67 | mapper_in[ileft] = oi |
| 68 | if reduce(mul, inshape[ileft : ii + 1]) != dout: |
| 69 | raise NotImplementedError(_not_implemented_message) |
| 70 | # Special case to avoid intermediate rechunking: |
| 71 | # When all the lower axis are completely chunked (chunksize=1) then |
| 72 | # we're simply moving around blocks. |
| 73 | if all(len(inchunks[i]) == inshape[i] for i in range(ii)): |
| 74 | for i in range(ii + 1): |
| 75 | result_inchunks[i] = inchunks[i] |
| 76 | result_outchunks[oi] = inchunks[ii] * math.prod( |
| 77 | map(len, inchunks[ileft:ii]) |
| 78 | ) |
| 79 | else: |
| 80 | for i in range(ileft + 1, ii + 1): # need single-shape dimensions |
| 81 | result_inchunks[i] = (inshape[i],) # chunks[i] = (4,) |
| 82 | |
| 83 | chunk_reduction = reduce(mul, map(len, inchunks[ileft + 1 : ii + 1])) |
| 84 | result_inchunks[ileft] = expand_tuple(inchunks[ileft], chunk_reduction) |
| 85 | |
| 86 | max_in_chunk = _cal_max_chunk_size(inchunks, ileft, ii) |
| 87 | result_inchunks = _smooth_chunks( |
| 88 | ileft, ii, max_in_chunk, result_inchunks |
| 89 | ) |
| 90 | # Build cross product of result_inchunks[ileft:ii+1] |
searching dependent graphs…