(root:UOp, targets:set[UOp])
| 86 | ]) |
| 87 | |
def _deepwalk(root:UOp, targets:set[UOp]) -> tuple[list[UOp], dict[UOp, bool]]:
  """Mark which nodes under `root` depend on any of `targets`, and list the ones gradients may flow through.

  Returns a pair `(walk, in_target_path)`:
    - `in_target_path` maps every node recorded by `root.topovisit` to True when at least one of its
      direct sources is in `targets` or is itself marked True, else False. A node with no sources
      (including a bare target leaf) is therefore marked False.
    - `walk` lists the nodes marked True, in the dict's insertion order, skipping DETACH nodes.
  """
  on_path: dict[UOp, bool] = {}

  # NOTE(review): assumes topovisit visits every source before its users and stores fn(u) into the
  # dict passed as its cache; the lookup on_path[s] below relies on that ordering -- confirm.
  def _reaches_target(u:UOp) -> bool:
    return any(on_path[s] or s in targets for s in u.src)

  root.topovisit(_reaches_target, on_path)
  # gradients must not flow through DETACH, nor through nodes that don't lead to a target
  walk = [u for u, hit in on_path.items() if hit and u.op is not Ops.DETACH]
  return walk, on_path
| 94 | |
| 95 | def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]: |
| 96 | walk, in_target_path = _deepwalk(root, targets) |
no test coverage detected
searching dependent graphs…