| 149 | |
| 150 | |
def test_dataflow_fold():
    """Check that FoldConstant folds a ``call_tir`` inside a dataflow block.

    ``before`` passes its parameter ``c0`` through the ``identity`` PrimFunc
    inside an ``R.dataflow()`` block. ``expected`` simply returns its
    parameter ``c1``. Both parameters are bound to the same constant array.
    The test asserts that folding ``before`` yields a module structurally
    equal to ``expected``.
    """

    # NOTE: the TVMScript below is parsed into IR and compared structurally,
    # so it is kept token-for-token as written.
    @tvm.script.ir_module
    class Module:
        # Element-wise copy: B[i, j] = A[i, j] for a 16x16 float32 buffer.
        @T.prim_func(s_tir=True)
        def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None:
            for i, j in T.grid(16, 16):
                with T.sblock("identity"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj]

        @R.function
        def before(c0: R.Tensor((16, 16), "float32")):
            cls = Module
            with R.dataflow():
                gv0 = relax.call_tir(cls.identity, (c0,), R.Tensor((16, 16), dtype="float32"))
                R.output(gv0)
            return gv0

        @R.function
        def expected(c1: R.Tensor((16, 16), "float32")):
            return c1

    # Values 0..255 as a 16x16 float32 matrix. Every value is a small integer,
    # so creating the array directly in float32 is exact.
    const_data = np.arange(256, dtype="float32").reshape((16, 16))

    fold = relax.transform.FoldConstant()
    # gen_mod is a helper defined elsewhere in this file. It presumably
    # extracts the named function and binds the given parameter values.
    # TODO confirm against its definition.
    # The identity kernel copies its input, so both sides share one array.
    mod_before = gen_mod(Module, "before", {"c0": const_data})
    mod_expected = gen_mod(Module, "expected", {"c1": const_data})
    tvm.ir.assert_structural_equal(fold(mod_before), mod_expected)
| 179 | |
| 180 | |
| 181 | def test_fold_mixed_case(): |