Tests for R.emit, R.emit_match_cast, and R.emit_var_binding. Expected script (truncated): @R.function def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) -> R.Shape(ndim=2): m = T.int64() n = T.int64() gv: R.Tensor((m,), dtype="float32") = R.match_cast(x, R.Tensor((m,), dtype="float32")) …
()
| 73 | |
| 74 | |
def test_emits():
    """Check R.emit, R.emit_match_cast and R.emit_var_binding against BlockBuilder.

    The same Relax function is constructed twice -- once through the TVMScript
    IRBuilder frontend (``R.*`` helpers) and once through the imperative
    ``relax.BlockBuilder`` API -- and the two results must be structurally
    equal. The script form being mirrored is:

    @R.function
    def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) -> R.Shape(ndim=2):
        m = T.int64()
        n = T.int64()
        gv: R.Tensor((m,), dtype="float32") = R.match_cast(x, R.Tensor((m,), dtype="float32"))
        gv1: R.Tensor((n,), dtype="float32") = R.match_cast(y, R.Tensor((n,), dtype="float32"))
        v: R.Tensor((n,), dtype="float32") = gv1
        return R.shape([m, n * 2])

    NOTE(review): both construction paths additionally emit ``v`` once more
    (``R.emit(...)`` / ``bb.emit(...)``); that extra binding does not appear in
    the script above -- confirm whether the script should show it.
    """

    def f32_any_rank():
        # Fresh struct info for a float32 tensor of unknown rank (ndim=-1).
        return relax.TensorStructInfo(ndim=-1, dtype="float32")

    def f32_vector(extent):
        # Fresh struct info for a 1-D float32 tensor of length `extent`.
        return relax.TensorStructInfo((extent,), "float32")

    def build_via_ir_builder():
        # Path 1: TVMScript IRBuilder frontend.
        with IRBuilder() as builder, R.function():
            R.func_name("foo")
            arg_x = R.arg("x", f32_any_rank())
            arg_y = R.arg("y", f32_any_rank())
            sym_m = tirx.Var("m", dtype="int64")
            sym_n = tirx.Var("n", dtype="int64")
            # The casts introduce m / n as the lengths of x / y, as in the script.
            R.emit_match_cast(arg_x, f32_vector(sym_m))
            cast_y = R.emit_match_cast(arg_y, f32_vector(sym_n))
            # Bind a pre-built Var named "v" to the cast of y, then re-emit it.
            bound = R.emit_var_binding(
                relax.VarBinding(relax.Var("v", f32_vector(sym_n)), cast_y)
            )
            R.emit(bound)
            IRBuilder.name("v", bound)
            R.func_ret_value(relax.ShapeExpr([sym_m, sym_n * 2]))
        return builder.get()

    def build_via_block_builder():
        # Path 2: imperative relax.BlockBuilder API.
        sym_m = tirx.Var("m", dtype="int64")
        sym_n = tirx.Var("n", dtype="int64")
        param_x = relax.Var("x", f32_any_rank())
        param_y = relax.Var("y", f32_any_rank())
        var_v = relax.Var("v", f32_vector(sym_n))
        bb = relax.BlockBuilder()
        with bb.function("foo", (param_x, param_y)):
            bb.match_cast(param_x, f32_vector(sym_m))
            cast_y = bb.match_cast(param_y, f32_vector(sym_n))
            bb.emit_normalized(relax.VarBinding(var_v, cast_y))
            bb.emit(var_v)
            bb.emit_func_output(relax.ShapeExpr([sym_m, sym_n * 2]))
        return bb.get()["foo"]

    # The IRBuilder path is built first, then the BlockBuilder path (original order).
    tvm.ir.assert_structural_equal(build_via_ir_builder(), build_via_block_builder())
| 122 | |
| 123 | |
| 124 | def test_dataflow_block(): |
nothing calls this directly
no test coverage detected
searching dependent graphs…