| 179 | |
| 180 | |
| 181 | def test_fold_mixed_case(): |
| 182 | @tvm.script.ir_module |
| 183 | class Module: |
| 184 | # addone matches both buffers against a symbolic shape (n, m), so the same TIR function is called below with different output shapes. |
| 185 | @T.prim_func(s_tir=True) |
| 186 | def addone(a: T.handle, b: T.handle) -> None: |
| 187 | n = T.int32() |
| 188 | m = T.int32() |
| 189 | A = T.match_buffer(a, (n, m)) |
| 190 | B = T.match_buffer(b, (n, m)) |
| 191 | for i, j in T.grid(n, m): |
| 192 | with T.sblock("addone"): |
| 193 | vi, vj = T.axis.remap("SS", [i, j]) |
| 194 | B[vi, vj] = A[vi, vj] + T.float32(1) |
| 195 | |
| 196 | @T.prim_func(s_tir=True) |
| 197 | def sub( |
| 198 | A: T.Buffer((16, 16), "float32"), |
| 199 | B: T.Buffer((16, 16), "float32"), |
| 200 | C: T.Buffer((16, 16), "float32"), |
| 201 | ) -> None: |
| 202 | for i, j in T.grid(16, 16): |
| 203 | with T.sblock("sub"): |
| 204 | vi, vj = T.axis.remap("SS", [i, j]) |
| 205 | C[vi, vj] = A[vi, vj] - B[vi, vj] |
| 206 | |
| 207 | @R.function |
| 208 | def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)): |
| 209 | n, m = T.int64(), T.int64() |
| 210 | cls = Module |
| 211 | x0 = R.match_cast(x, R.Tensor((n, m), "float32")) |
| 212 | # Cannot be folded: the output shape (n, 16) depends on the symbolic n, which is unknown. |
| 213 | lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((n, 16), dtype="float32")) |
| 214 | # Can be folded: the output shape (16, 16) is static and the only input is c0 (presumably bound to a constant by the test body; confirm there). |
| 215 | lv1 = relax.call_tir(cls.addone, (c0,), R.Tensor((16, 16), dtype="float32")) |
| 216 | # Can be folded because all inputs are constant (c0, and lv1 once it has been folded). |
| 217 | lv2 = relax.call_tir(cls.sub, (c0, lv1), R.Tensor((16, 16), dtype="float32")) |
| 218 | # Cannot be folded because x's shape is unknown (only ndim=2 is declared). |
| 219 | lv3 = relax.call_tir(cls.sub, (lv2, x), R.Tensor((16, 16), dtype="float32")) |
| 220 | return (lv0, lv3) |
| 221 | |
| 222 | @R.function |
| 223 | def expected( |
| 224 | c0: R.Tensor((16, 16), "float32"), |
| 225 | c1: R.Tensor((16, 16), "float32"), |
| 226 | c2: R.Tensor((16, 16), "float32"), |
| 227 | x: R.Tensor("float32", ndim=2), |
| 228 | ): |
| 229 | n, m = T.int64(), T.int64() |
| 230 | cls = Module |
| 231 | x0 = R.match_cast(x, R.Tensor((n, m), "float32")) |
| 232 | # Cannot be folded: the output shape (n, 16) depends on the symbolic n, which is unknown. |
| 233 | lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((n, 16), dtype="float32")) |
| 234 | # Cannot be folded because x's shape is unknown; c2 takes the place of lv2 from `before`. |
| 235 | lv3 = relax.call_tir(cls.sub, (c2, x), R.Tensor((16, 16), dtype="float32")) |
| 236 | return (lv0, lv3) |
| 237 | |
| 238 | c0_np = np.arange(16 * 16).astype("float32").reshape(16, 16) |