MCPcopy Index your code
hub / github.com/apache/tvm / test_fold_mixed_case

Function test_fold_mixed_case

tests/python/relax/test_transform_fold_constant.py:181–245  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

179
180
181def test_fold_mixed_case():
182 @tvm.script.ir_module
183 class Module:
184 # TIR function can handle different cases.
185 @T.prim_func(s_tir=True)
186 def addone(a: T.handle, b: T.handle) -> None:
187 n = T.int32()
188 m = T.int32()
189 A = T.match_buffer(a, (n, m))
190 B = T.match_buffer(b, (n, m))
191 for i, j in T.grid(n, m):
192 with T.sblock("addone"):
193 vi, vj = T.axis.remap("SS", [i, j])
194 B[vi, vj] = A[vi, vj] + T.float32(1)
195
196 @T.prim_func(s_tir=True)
197 def sub(
198 A: T.Buffer((16, 16), "float32"),
199 B: T.Buffer((16, 16), "float32"),
200 C: T.Buffer((16, 16), "float32"),
201 ) -> None:
202 for i, j in T.grid(16, 16):
203 with T.sblock("sub"):
204 vi, vj = T.axis.remap("SS", [i, j])
205 C[vi, vj] = A[vi, vj] - B[vi, vj]
206
207 @R.function
208 def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32", ndim=2)):
209 n, m = T.int64(), T.int64()
210 cls = Module
211 x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
212 # this line cannot be folded because n is unknown
213 lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((n, 16), dtype="float32"))
214 # this line can be folded
215 lv1 = relax.call_tir(cls.addone, (c0,), R.Tensor((16, 16), dtype="float32"))
216 # this line can be folded because all inputs are const
217 lv2 = relax.call_tir(cls.sub, (c0, lv1), R.Tensor((16, 16), dtype="float32"))
218 # this line can not be folded because x's shape is unknown
219 lv3 = relax.call_tir(cls.sub, (lv2, x), R.Tensor((16, 16), dtype="float32"))
220 return (lv0, lv3)
221
222 @R.function
223 def expected(
224 c0: R.Tensor((16, 16), "float32"),
225 c1: R.Tensor((16, 16), "float32"),
226 c2: R.Tensor((16, 16), "float32"),
227 x: R.Tensor("float32", ndim=2),
228 ):
229 n, m = T.int64(), T.int64()
230 cls = Module
231 x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
232 # this line cannot be folded because n is unknown
233 lv0 = relax.call_tir(cls.addone, (c0,), R.Tensor((n, 16), dtype="float32"))
234 # this line can not be folded because x's shape is unknown
235 lv3 = relax.call_tir(cls.sub, (c2, x), R.Tensor((16, 16), dtype="float32"))
236 return (lv0, lv3)
237
238 c0_np = np.arange(16 * 16).astype("float32").reshape(16, 16)

Callers

nothing calls this directly

Calls 4

gen_mod · Function · 0.85
reshape · Method · 0.45
astype · Method · 0.45
arange · Method · 0.45

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…