MCPcopy Index your code
hub / github.com/apache/tvm / test_emits

Function test_emits

tests/python/relax/test_tvmscript_ir_builder.py:75–121  ·  view source on GitHub ↗

Tests for R.emit, R.emit_match_cast, and R.emit_var_binding. Docstring preview (truncated): @R.function def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) -> R.Shape(ndim=2): m = T.int64() n = T.int64() gv: R.Tensor((m,), dtype="float32") = R.match_cast(x, R.Tensor((m,), dtype="float32… (see the full source below)

()

Source from the content-addressed store, hash-verified

73
74
def test_emits():
    """Tests for R.emit, R.emit_match_cast, R.emit_var_binding

    Builds the same Relax function ``foo`` twice -- once through the
    TVMScript IRBuilder frame API (``R.*``) and once through
    ``relax.BlockBuilder`` -- and asserts the two results are structurally
    equal. The function being constructed corresponds to:

    @R.function
    def foo(x: R.Tensor(dtype="float32"), y: R.Tensor(dtype="float32")) -> R.Shape(ndim=2):
        m = T.int64()
        n = T.int64()
        gv: R.Tensor((m,), dtype="float32") = R.match_cast(x, R.Tensor((m,), dtype="float32"))
        gv1: R.Tensor((n,), dtype="float32") = R.match_cast(y, R.Tensor((n,), dtype="float32"))
        v: R.Tensor((n,), dtype="float32") = gv1
        return R.shape([m, n * 2])
    """
    # create with Script IRBuilder
    with IRBuilder() as ir_builder:
        with R.function():
            R.func_name("foo")
            # Both parameters are float32 tensors of unknown rank (ndim=-1).
            x = R.arg("x", relax.TensorStructInfo(ndim=-1, dtype="float32"))
            y = R.arg("y", relax.TensorStructInfo(ndim=-1, dtype="float32"))
            m = tirx.Var("m", dtype="int64")
            n = tirx.Var("n", dtype="int64")
            # Cast x and y to 1-D float32 tensors whose lengths are the
            # symbolic vars m and n; the result of the first cast is unused.
            _ = R.emit_match_cast(x, relax.TensorStructInfo((m,), "float32"))
            y1 = R.emit_match_cast(y, relax.TensorStructInfo((n,), "float32"))
            # Exercise R.emit_var_binding with a pre-built VarBinding `v = y1`.
            v = relax.Var("v", relax.TensorStructInfo((n,), "float32"))
            vb = relax.VarBinding(v, y1)
            v = R.emit_var_binding(vb)
            # NOTE(review): `v` is emitted once more here (mirrored by
            # `bb.emit(v)` below); that extra binding does not appear in the
            # docstring's script -- confirm the docstring is up to date.
            R.emit(v)

            IRBuilder.name("v", v)
            R.func_ret_value(relax.ShapeExpr([m, n * 2]))
    func = ir_builder.get()

    # create with BlockBuilder
    # NOTE(review): fresh m/n/x/y/v objects are created here, so the equality
    # check below presumably relies on structural equality mapping defined
    # variables rather than comparing object identity -- confirm.
    m = tirx.Var("m", dtype="int64")
    n = tirx.Var("n", dtype="int64")
    x = relax.Var("x", relax.TensorStructInfo(dtype="float32", ndim=-1))
    y = relax.Var("y", relax.TensorStructInfo(dtype="float32", ndim=-1))
    v = relax.Var("v", relax.TensorStructInfo((n,), "float32"))
    bb = relax.BlockBuilder()
    with bb.function("foo", (x, y)):
        _ = bb.match_cast(x, relax.TensorStructInfo((m,), "float32"))
        y1 = bb.match_cast(y, relax.TensorStructInfo((n,), "float32"))
        # BlockBuilder counterpart of R.emit_var_binding: emit the
        # pre-built binding `v = y1` as-is.
        bb.emit_normalized(relax.VarBinding(v, y1))
        bb.emit(v)
        bb.emit_func_output(relax.ShapeExpr([m, n * 2]))
    mod = bb.get()

    tvm.ir.assert_structural_equal(func, mod["foo"])
122
123
124def test_dataflow_block():

Callers

nothing calls this directly

Calls 10

emitMethod · 0.95
functionMethod · 0.95
match_castMethod · 0.95
emit_normalizedMethod · 0.95
emit_func_outputMethod · 0.95
getMethod · 0.95
IRBuilderClass · 0.90
functionMethod · 0.45
nameMethod · 0.45
getMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…