| 178 | |
| 179 | |
def test_two_subfunction():
    """Check a primitive relax function that ``main`` invokes twice in a row.

    ``before`` builds a module with:
      * ``fused_exp_squeeze`` -- a relax function tagged ``{"Primitive": True}``
        that applies ``topi.exp`` and then ``topi.squeeze`` to its input;
      * ``main`` -- calls ``fused_exp_squeeze`` on ``x``, then calls it again
        on the result of the first call.

    ``expected`` expresses the same computation directly as two TE calls built
    from a single Python TE function. ``_check`` is defined elsewhere in this
    file; presumably it applies the transform under test to ``before()`` and
    compares the result against ``expected()`` -- confirm against its definition.
    """

    def _float_var(name):
        # Every relax variable in this test is a [10, 20] float32 tensor.
        return relax.Var(name, R.Tensor([10, 20], "float32"))

    def before():
        builder = relax.BlockBuilder()

        # Subfunction: exp followed by squeeze, marked as a primitive.
        inner_x = _float_var("x1")
        with builder.function("fused_exp_squeeze", [inner_x], attrs={"Primitive": True}):
            with builder.dataflow():
                exp_out = builder.emit_te(topi.exp, inner_x)
                squeezed = builder.call_te(topi.squeeze, exp_out)
                inner_out = builder.emit_output(squeezed)
            builder.emit_func_output(inner_out)

        # Fetch the subfunction's GlobalVar so that ``main`` can call it.
        subfunc = builder.get().get_global_var("fused_exp_squeeze")

        # ``main`` chains two calls to the same subfunction.
        main_x = _float_var("x")
        with builder.function("main", [main_x]):
            with builder.dataflow():
                first = builder.emit(relax.Call(subfunc, [main_x]))
                second = builder.emit(relax.Call(subfunc, [first]))
                main_out = builder.emit_output(second)
            builder.emit_func_output(main_out)
        return builder.get()

    def expected():
        # NOTE: keep this function name as-is. BlockBuilder.emit_te/call_te
        # appear to use the TE function's name as the PrimFunc name hint, so
        # renaming it could change the expected module.
        def fused_exp_squeeze(x):
            return topi.squeeze(topi.exp(x))

        builder = relax.BlockBuilder()
        main_x = _float_var("x")
        with builder.function("main", [main_x]):
            with builder.dataflow():
                first = builder.emit_te(fused_exp_squeeze, main_x)
                main_out = builder.emit_output(builder.call_te(fused_exp_squeeze, first))
            builder.emit_func_output(main_out)
        return builder.get()

    _check(before(), expected())
| 218 | |
| 219 | |
| 220 | def test_fuse_same_primfunc(): |