()
| 44 | |
| 45 | |
def test_use_def():
    """Check that udchain maps each variable of a dataflow block to its users.

    Builds ``func(x, y)`` whose only dataflow block computes
    ``lv0 = add(x, y)`` and ``lv1 = multiply(lv0, y)``, then emits ``lv1`` as
    the block output ``gv0``. The use-def chain of that block is then
    checked variable by variable.
    """
    rows = tirx.Var("m", "int64")
    cols = tirx.Var("n", "int64")
    lhs = rx.Var("x", R.Tensor([rows, cols], "float16"))
    rhs = rx.Var("y", R.Tensor([cols], "float16"))

    builder = rx.BlockBuilder()
    with builder.function("func", [lhs, rhs]):
        with builder.dataflow():
            summed = builder.emit(rx.op.add(lhs, rhs))
            product = builder.emit(rx.op.multiply(summed, rhs))
            result = builder.emit_output(product)
        # The function return is emitted outside the dataflow block, so it
        # does not count as a user of `result` within that block.
        builder.emit_func_output(result)

    block = builder.get()["func"].body.blocks[0]
    users = udchain(block)

    # (variable, expected set of variables bound by expressions that use it),
    # checked in the same order as the bindings were emitted.
    expected_users = [
        (lhs, {summed}),
        (rhs, {summed, product}),
        (summed, {product}),
        (product, {result}),
        (result, set()),
    ]
    for var, consumers in expected_users:
        assert set(users[var]) == consumers
| 65 | |
| 66 | |
| 67 | @pytest.mark.parametrize( |
nothing calls this directly
no test coverage detected
searching dependent graphs…