()
| 299 | |
| 300 | |
def test_simple_module_update():
    """Check that DataflowBlockRewrite.add + mutate_irmodule updates a module.

    A ``tmp = x`` binding is appended to the single dataflow block of
    ``Identity.main``. The test asserts two things:

    * ``mutate_irmodule`` returns a distinct module that contains the new
      binding.
    * The input module still holds only its original binding.

    Finally, the result is compared structurally against a hand-written
    ground truth.
    """

    @tvm.script.ir_module
    class Identity:
        @R.function
        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
            with R.dataflow():
                lv0 = x
                R.output(lv0)
            return lv0

    root_fn = Identity["main"]
    dfb = root_fn.body.blocks[0]

    rwt = DataflowBlockRewrite(dfb, root_fn)
    # Bind the function's only parameter `x` to a new variable named `tmp`;
    # is_dfvar=True requests a dataflow (block-local) variable.
    rwt.add(name="tmp", expr=root_fn.params[0], is_dfvar=True)

    new_ir = rwt.mutate_irmodule(Identity)

    # Immutability: the rewrite must yield a different module object, and the
    # original module must still contain exactly its single `lv0 = x` binding.
    assert new_ir != Identity
    assert 1 == len(Identity["main"].body.blocks[0].bindings)
    assert 2 == len(new_ir["main"].body.blocks[0].bindings)

    @tvm.script.ir_module
    class GroundTruth:
        @R.function
        def main(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
            with R.dataflow():
                lv0 = x
                tmp: R.Tensor((32, 32), "float32") = x
                R.output(lv0)
            return lv0

    tvm.ir.assert_structural_equal(new_ir, GroundTruth)
| 334 | |
| 335 | |
| 336 | if __name__ == "__main__": |
nothing calls this directly
no test coverage detected
searching dependent graphs…