| 73 | |
| 74 | |
| 75 | def test_conv2d_fuse(): |
| 76 | def before(dtype): |
| 77 | bb = relax.BlockBuilder() |
| 78 | |
| 79 | # Grouped function 1 |
| 80 | x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) |
| 81 | w = relax.Var("w", R.Tensor((16, 16, 3, 3), dtype)) |
| 82 | p0 = relax.Var("p0", R.Tensor((), dtype)) |
| 83 | with bb.function("fused_conv2d_add1_add2", [x, w, p0], attrs={"Primitive": True}): |
| 84 | with bb.dataflow(): |
| 85 | lv0 = bb.emit_te( |
| 86 | topi.nn.conv2d, |
| 87 | x, |
| 88 | w, |
| 89 | strides=1, |
| 90 | padding=1, |
| 91 | dilation=1, |
| 92 | primfunc_name_hint="conv2d", |
| 93 | ) |
| 94 | lv1 = bb.emit_te(topi.add, p0, lv0, primfunc_name_hint="add1") |
| 95 | gv = bb.emit_output(bb.call_te(topi.add, lv0, lv1, primfunc_name_hint="add2")) |
| 96 | bb.emit_func_output(gv) |
| 97 | |
| 98 | # Grouped function 2 |
| 99 | x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) |
| 100 | w = relax.Var("w", R.Tensor((16, 16, 1, 1), dtype)) |
| 101 | y = relax.Var("y", R.Tensor((1, 16, 64, 64), dtype)) |
| 102 | with bb.function("fused_conv2d1_add2", [x, w, y], attrs={"Primitive": True}): |
| 103 | with bb.dataflow(): |
| 104 | lv0 = bb.emit_te( |
| 105 | topi.nn.conv2d, |
| 106 | x, |
| 107 | w, |
| 108 | strides=1, |
| 109 | padding=0, |
| 110 | dilation=1, |
| 111 | primfunc_name_hint="conv2d1", |
| 112 | ) |
| 113 | gv = bb.emit_output(bb.call_te(topi.add, lv0, y, primfunc_name_hint="add2")) |
| 114 | bb.emit_func_output(gv) |
| 115 | |
| 116 | # Get the global variables of the grouped functions |
| 117 | mod = bb.get() |
| 118 | fused_conv2d_add1_add2 = mod.get_global_var("fused_conv2d_add1_add2") |
| 119 | fused_conv2d1_add2 = mod.get_global_var("fused_conv2d1_add2") |
| 120 | |
| 121 | # Main function |
| 122 | x = relax.Var("x", R.Tensor((1, 16, 64, 64), dtype)) |
| 123 | w1 = relax.Var("w1", R.Tensor((16, 16, 3, 3), dtype)) |
| 124 | w2 = relax.Var("w2", R.Tensor((16, 16, 1, 1), dtype)) |
| 125 | w3 = relax.Var("w3", R.Tensor((16, 16, 3, 3), dtype)) |
| 126 | with bb.function("main", [x, w1, w2, w3]): |
| 127 | with bb.dataflow(): |
| 128 | lv0 = bb.emit_te(topi.add, x, relax.const(1, dtype)) |
| 129 | lv1 = bb.emit(relax.Call(fused_conv2d_add1_add2, [lv0, w1, relax.const(1, dtype)])) |
| 130 | lv2 = bb.emit_te( |
| 131 | topi.nn.conv2d, |
| 132 | lv1, |