MCPcopy Index your code
hub / github.com/apache/tvm / test_two_subfunction

Function test_two_subfunction

tests/python/relax/test_transform_fuse_tir.py:180–217  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

178
179
def test_two_subfunction():
    """Fuse a primitive subfunction that ``main`` invokes twice in a row.

    ``before`` builds a module holding a relax function ``fused_exp_squeeze``
    (attribute ``Primitive``; ``exp`` followed by ``squeeze`` on a 10x20
    float32 tensor) and a ``main`` that passes ``x`` through it and then
    passes that result through it a second time. ``expected`` spells out the
    same computation with each call emitted directly from one TE function.
    ``_check`` (defined elsewhere in this file) compares the two; presumably
    it runs the FuseTIR pass on ``before`` first — confirm against its
    definition.
    """

    def before():
        builder = relax.BlockBuilder()

        # Primitive subfunction: exp, then squeeze.
        sub_input = relax.Var("x1", R.Tensor([10, 20], "float32"))
        with builder.function(
            "fused_exp_squeeze", [sub_input], attrs={"Primitive": True}
        ):
            with builder.dataflow():
                exp_out = builder.emit_te(topi.exp, sub_input)
                squeeze_call = builder.call_te(topi.squeeze, exp_out)
                sub_out = builder.emit_output(squeeze_call)
            builder.emit_func_output(sub_out)

        subfunc = builder.get().get_global_var("fused_exp_squeeze")

        # main chains two calls to the same subfunction.
        main_input = relax.Var("x", R.Tensor([10, 20], "float32"))
        with builder.function("main", [main_input]):
            with builder.dataflow():
                first = builder.emit(relax.Call(subfunc, [main_input]))
                second = builder.emit(relax.Call(subfunc, [first]))
                result = builder.emit_output(second)
            builder.emit_func_output(result)
        return builder.get()

    def expected():
        # One TE function covering the whole exp + squeeze body. Its Python
        # name is kept as-is because emit_te/call_te may use it to name the
        # generated PrimFunc.
        def fused_exp_squeeze(x):
            return topi.squeeze(topi.exp(x))

        builder = relax.BlockBuilder()
        main_input = relax.Var("x", R.Tensor([10, 20], "float32"))
        with builder.function("main", [main_input]):
            with builder.dataflow():
                first = builder.emit_te(fused_exp_squeeze, main_input)
                result = builder.emit_output(
                    builder.call_te(fused_exp_squeeze, first)
                )
            builder.emit_func_output(result)
        return builder.get()

    _check(before(), expected())
218
219
220def test_fuse_same_primfunc():

Callers

nothing calls this directly

Calls 3

_check · Function · 0.70
before · Function · 0.70
expected · Function · 0.70

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…