MCPcopy Index your code
hub / github.com/apache/tvm / test_dataflow_fold

Function test_dataflow_fold

tests/python/relax/test_transform_fold_constant.py:151–178  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

149
150
def test_dataflow_fold():
    """FoldConstant should fold an identity ``call_tir`` inside a dataflow block.

    ``before`` passes its input through the ``identity`` PrimFunc inside an
    ``R.dataflow()`` block; ``expected`` simply returns its input. Both
    parameters are bound to the same NumPy array through ``gen_mod`` (a helper
    defined elsewhere in this file -- presumably it binds the named parameters
    to constants; confirm there). After ``FoldConstant`` runs, the folded
    ``before`` module must be structurally equal to ``expected``.
    """

    @tvm.script.ir_module
    class Module:
        @T.prim_func(s_tir=True)
        def identity(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")) -> None:
            for i, j in T.grid(16, 16):
                with T.sblock("identity"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    B[vi, vj] = A[vi, vj]

        @R.function
        def before(c0: R.Tensor((16, 16), "float32")):
            cls = Module
            with R.dataflow():
                gv0 = relax.call_tir(cls.identity, (c0,), R.Tensor((16, 16), dtype="float32"))
                R.output(gv0)
            return gv0

        @R.function
        def expected(c1: R.Tensor((16, 16), "float32")):
            return c1

    # One 16x16 float32 array feeds both functions: an identity copy of a
    # constant should fold to that very constant.
    const_data = np.arange(16 * 16, dtype="float32").reshape((16, 16))

    mod_before = gen_mod(Module, "before", {"c0": const_data})
    mod_expected = gen_mod(Module, "expected", {"c1": const_data})

    folded = relax.transform.FoldConstant()(mod_before)
    tvm.ir.assert_structural_equal(folded, mod_expected)
179
180
181def test_fold_mixed_case():

Callers

nothing calls this directly

Calls 4

gen_mod — Function · 0.85
reshape — Method · 0.45
astype — Method · 0.45
arange — Method · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…