MCPcopy Index your code
hub / github.com/apache/tvm / test_conv2d_fuse

Function test_conv2d_fuse

tests/python/relax/test_transform_fuse_tir.py:75–177  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

73
74
75def test_conv2d_fuse():
76 def before(dtype):
77 bb = relax.BlockBuilder()
78
79 # Grouped function 1
80 x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
81 w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype))
82 p0 = relax.Var("p0", R.Tensor((), dtype))
83 with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": True}):
84 with bb.dataflow():
85 lv0 = bb.emit_te(
86 topi.nn.conv2d,
87 x,
88 w,
89 strides=1,
90 padding=1,
91 dilation=1,
92 primfunc_name_hint="conv2d",
93 )
94 lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1")
95 gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2"))
96 bb.emit_func_output(gv)
97
98 # Grouped function 2
99 x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
100 w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype))
101 y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype))
102 with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": True}):
103 with bb.dataflow():
104 lv0 = bb.emit_te(
105 topi.nn.conv2d,
106 x,
107 w,
108 strides=1,
109 padding=0,
110 dilation=1,
111 primfunc_name_hint="conv2d1",
112 )
113 gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2"))
114 bb.emit_func_output(gv)
115
116 # Get the global variables of the grouped functions
117 mod = bb.get()
118 fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2")
119 fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2")
120
121 # Main function
122 x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype))
123 w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype))
124 w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype))
125 w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype))
126 with bb.function("main", [x, w1, w2, w3]):
127 with bb.dataflow():
128 lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype))
129 lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)]))
130 lv2 = bb.emit_te(
131 topi.nn.conv2d,
132 lv1,

Callers

nothing calls this directly

Calls 3

_check (Function) · 0.70
before (Function) · 0.70
expected (Function) · 0.70

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…