()
| 51 | |
| 52 | |
def test_simple_add():
    """Add a dataflow-var binding named "tmp" to the first block of ``Identity["main"]``.

    The new binding is bound to the function's first parameter. The test then:
    runs ``assert_immutability`` on the rewriter and its inputs, checks that
    "tmp" appears in the mutated function's bindings, and compares the mutated
    function structurally against a hand-written expected module.
    """
    func = Identity["main"]
    block = func.body.blocks[0]

    rewriter = DataflowBlockRewrite(block, func)
    # Bind "tmp" to the function's first parameter; is_dfvar=True requests a dataflow var.
    rewriter.add(name="tmp", expr=Identity["main"].params[0], is_dfvar=True)

    # NOTE(review): presumably verifies the rewrite left `block`/`func` unmodified;
    # confirm against the assert_immutability helper.
    assert_immutability(rewriter, block, func)

    # The new binding must be visible by name in the rewritten function.
    bindings = name_to_binding(rewriter.mutated_root_fn())
    assert "tmp" in bindings

    # Expected result: the original `lv0 = x` binding plus the appended `tmp = x`.
    @tvm.script.ir_module
    class Expected:
        @R.function
        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
            with R.dataflow():
                lv0 = x
                tmp: R.Tensor((32, 32), "float32") = x
                R.output(lv0)
            return lv0

    tvm.ir.assert_structural_equal(rewriter.mutated_root_fn(), Expected["main"])
| 76 | |
| 77 | |
| 78 | def test_simple_auto_add_var(): |
nothing calls this directly
no test coverage detected
searching dependent graphs…