(full_graph, keys=())
| 1107 | |
| 1108 | |
| 1109 | def _optimize_blockwise(full_graph, keys=()): |
| 1110 | keep = {k[0] if type(k) is tuple else k for k in keys} |
| 1111 | layers = full_graph.layers |
| 1112 | dependents = reverse_dict(full_graph.dependencies) |
| 1113 | roots = {k for k in full_graph.layers if not dependents.get(k)} |
| 1114 | stack = list(roots) |
| 1115 | |
| 1116 | out = {} |
| 1117 | dependencies = {} |
| 1118 | seen = set() |
| 1119 | io_names = set() |
| 1120 | |
| 1121 | while stack: |
| 1122 | layer = stack.pop() |
| 1123 | if layer in seen or layer not in layers: |
| 1124 | continue |
| 1125 | seen.add(layer) |
| 1126 | |
| 1127 | # Outer loop walks through possible output Blockwise layers |
| 1128 | if isinstance(layers[layer], Blockwise): |
| 1129 | blockwise_layers = {layer} |
| 1130 | deps = set(blockwise_layers) |
| 1131 | io_names |= layers[layer].io_deps.keys() |
| 1132 | while deps: # we gather as many sub-layers as we can |
| 1133 | dep = deps.pop() |
| 1134 | |
| 1135 | if dep not in layers: |
| 1136 | stack.append(dep) |
| 1137 | continue |
| 1138 | if not isinstance(layers[dep], Blockwise): |
| 1139 | stack.append(dep) |
| 1140 | continue |
| 1141 | if dep != layer and dep in keep: |
| 1142 | stack.append(dep) |
| 1143 | continue |
| 1144 | if layers[dep].concatenate != layers[layer].concatenate: |
| 1145 | stack.append(dep) |
| 1146 | continue |
| 1147 | if ( |
| 1148 | sum(k == dep for k, ind in layers[layer].indices if ind is not None) |
| 1149 | > 1 |
| 1150 | ): |
| 1151 | stack.append(dep) |
| 1152 | continue |
| 1153 | if blockwise_layers and not _can_fuse_annotations( |
| 1154 | layers[next(iter(blockwise_layers))].annotations, |
| 1155 | layers[dep].annotations, |
| 1156 | ): |
| 1157 | stack.append(dep) |
| 1158 | continue |
| 1159 | |
| 1160 | # passed everything, proceed |
| 1161 | blockwise_layers.add(dep) |
| 1162 | |
| 1163 | # traverse further to this child's children |
| 1164 | output_indices = set(layers[dep].output_indices) |
| 1165 | input_indices = { |
| 1166 | i for _, ind in layers[dep].indices if ind for i in ind |
no test coverage detected
searching dependent graphs…