(self)
def bidir_append(ctx, x, b):
  """Log one visit of node `x` into the list `ctx`.

  Appends a `(label, b)` tuple to `ctx`. `label` is `x.arg` when `x.op` is
  `Ops.CONST` and the placeholder string "+" for every other op. `b` is the
  caller-supplied tag and is stored unchanged.

  Returns None. NOTE(review): presumably this is so the pattern matcher treats
  the match as "no rewrite" and the graph stays unchanged; confirm against
  PatternMatcher.
  """
  label = x.arg if x.op is Ops.CONST else "+"
  ctx.append((label, b))
class TestBidirectional(unittest.TestCase):
  """Visit-order checks for graph_rewrite when both `pm` and `bpm` are supplied."""
  def test_simple(self):
    # Expression under test: CONST(1) + CONST(2).
    lhs, rhs = UOp.const(dtypes.int, 1), UOp.const(dtypes.int, 2)
    expr = lhs + rhs
    # Both matchers match every node and only record the visit. The bool tags
    # which matcher fired: True for bpm, False for pm.
    # NOTE(review): the lambdas capture no locals, exactly as before.
    # PatternMatcher may not support closures; confirm before changing them.
    record_pm = PatternMatcher([ (UPat(GroupOp.All, name="x"), lambda ctx,x: bidir_append(ctx, x, False)) ])
    record_bpm = PatternMatcher([ (UPat(GroupOp.All, name="x"), lambda ctx,x: bidir_append(ctx, x, True)) ])
    visits = []
    graph_rewrite(expr, record_pm, ctx=visits, bpm=record_bpm)
    # Expected order: bpm sees each node before its sources, and pm sees it after them.
    expected = [('+', True), (1, True), (1, False), (2, True), (2, False), ('+', False)]
    self.assertListEqual(visits, expected)
| 342 | |
| 343 | class TestStopEarly(unittest.TestCase): |
| 344 | def test_stop_early(self): |
nothing calls this directly
no test coverage detected