Traverse the expression graph and apply fusion
(expr)
| 3252 | |
| 3253 | |
| 3254 | def optimize_blockwise_fusion(expr): |
| 3255 | """Traverse the expression graph and apply fusion""" |
| 3256 | |
| 3257 | def _fusion_pass(expr): |
| 3258 | # Full pass to find global dependencies |
| 3259 | seen = set() |
| 3260 | stack = [expr] |
| 3261 | dependents = defaultdict(set) |
| 3262 | dependencies = {} |
| 3263 | expr_mapping = {} |
| 3264 | |
| 3265 | while stack: |
| 3266 | next = stack.pop() |
| 3267 | |
| 3268 | if next._name in seen: |
| 3269 | continue |
| 3270 | seen.add(next._name) |
| 3271 | |
| 3272 | if is_valid_blockwise_op(next): |
| 3273 | dependencies[next._name] = set() |
| 3274 | if next._name not in dependents: |
| 3275 | dependents[next._name] = set() |
| 3276 | expr_mapping[next._name] = next |
| 3277 | |
| 3278 | for operand in next.dependencies(): |
| 3279 | stack.append(operand) |
| 3280 | if is_valid_blockwise_op(operand): |
| 3281 | if next._name in dependencies: |
| 3282 | dependencies[next._name].add(operand._name) |
| 3283 | dependents[operand._name].add(next._name) |
| 3284 | expr_mapping[operand._name] = operand |
| 3285 | expr_mapping[next._name] = next |
| 3286 | |
| 3287 | # Traverse each "root" until we find a fusable sub-group. |
| 3288 | # Here we use root to refer to a Blockwise Expr node that |
| 3289 | # has no Blockwise dependents |
| 3290 | roots = [ |
| 3291 | expr_mapping[k] |
| 3292 | for k, v in dependents.items() |
| 3293 | if v == set() |
| 3294 | or all(not is_valid_blockwise_op(expr_mapping[_expr]) for _expr in v) |
| 3295 | ] |
| 3296 | while roots: |
| 3297 | root = roots.pop() |
| 3298 | seen = set() |
| 3299 | stack = [root] |
| 3300 | group = [] |
| 3301 | while stack: |
| 3302 | next = stack.pop() |
| 3303 | |
| 3304 | if next._name in seen: |
| 3305 | continue |
| 3306 | seen.add(next._name) |
| 3307 | |
| 3308 | group.append(next) |
| 3309 | for dep_name in dependencies[next._name]: |
| 3310 | dep = expr_mapping[dep_name] |
| 3311 |
no test coverage detected
searching dependent graphs…