Rewrite a stack of Blockwise expressions into a single blockwise expression. Given a set of Blockwise layers, combine them into a single layer. The provided layers are expected to fit well together; ensuring that is the job of ``optimize_blockwise``. Parameters: inputs
(inputs)
| 1260 | |
| 1261 | |
| 1262 | def rewrite_blockwise(inputs): |
| 1263 | """Rewrite a stack of Blockwise expressions into a single blockwise expression |
| 1264 | |
| 1265 | Given a set of Blockwise layers, combine them into a single layer. The provided |
| 1266 | layers are expected to fit well together. That job is handled by |
| 1267 | ``optimize_blockwise`` |
| 1268 | |
| 1269 | Parameters |
| 1270 | ---------- |
| 1271 | inputs : list[Blockwise] |
| 1272 | |
| 1273 | Returns |
| 1274 | ------- |
| 1275 | blockwise: Blockwise |
| 1276 | |
| 1277 | See Also |
| 1278 | -------- |
| 1279 | optimize_blockwise |
| 1280 | """ |
| 1281 | if len(inputs) == 1: |
| 1282 | # Fast path: if there's only one input we can just use it as-is. |
| 1283 | return inputs[0] |
| 1284 | |
| 1285 | fused_annotations = _fuse_annotations( |
| 1286 | *[i.annotations for i in inputs if i.annotations] |
| 1287 | ) |
| 1288 | inputs = {inp.output: inp for inp in inputs} |
| 1289 | dependencies = { |
| 1290 | inp.output: {d for d, v in inp.indices if v is not None and d in inputs} |
| 1291 | for inp in inputs.values() |
| 1292 | } |
| 1293 | dependents = reverse_dict(dependencies) |
| 1294 | |
| 1295 | new_index_iter = ( |
| 1296 | c + (str(d) if d else "") # A, B, ... A1, B1, ... |
| 1297 | for d in itertools.count() |
| 1298 | for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ" |
| 1299 | ) |
| 1300 | |
| 1301 | [root] = [k for k, v in dependents.items() if not v] |
| 1302 | |
| 1303 | # Our final results. These will change during fusion below |
| 1304 | indices = list(inputs[root].indices) |
| 1305 | new_axes = inputs[root].new_axes |
| 1306 | concatenate = inputs[root].concatenate |
| 1307 | task = inputs[root].task |
| 1308 | dsk = {task.key: task} |
| 1309 | |
| 1310 | changed = True |
| 1311 | while changed: |
| 1312 | changed = False |
| 1313 | for i, (dep, current_dep_indices) in enumerate(indices): |
| 1314 | if current_dep_indices is None: |
| 1315 | continue |
| 1316 | if dep not in inputs: |
| 1317 | continue |
| 1318 | |
| 1319 | changed = True |
searching dependent graphs…