Repository: github.com/apache/tvm, branch main

Method main

tests/python/relax/test_transform_bind_params.py:80–99

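```python
import numpy as np
import tvm
from tvm.script import relax as R
from tvm.script import tir as T


# Assumption: line 77 (outside this excerpt) decorates the class so that it is
# parsed into an IRModule; @tvm.script.ir_module is used here for that purpose.
@tvm.script.ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor(("batch", "m"), dtype="float32"),
        w0: R.Tensor(("n", "m"), dtype="float32"),
        b0: R.Tensor(("n",), dtype="float32"),
        w1: R.Tensor(("k", "n"), dtype="float32"),
        b1: R.Tensor(("k",), dtype="float32"),
    ) -> R.Tensor(("batch", "k"), dtype="float32"):
        # Symbolic shape variables shared by the signature and the body
        batch = T.Var("batch", "int64")
        k = T.Var("k", "int64")
        m = T.Var("m", "int64")
        n = T.Var("n", "int64")
        with R.dataflow():
            # First layer: the packed function "linear0" writes a (batch, n) result
            lv0 = R.call_dps_packed(
                "linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, n), dtype="float32")
            )
            # Second layer: "linear1" writes the final (batch, k) result
            out = R.call_dps_packed(
                "linear1", (lv0, w1, b1), out_sinfo=R.Tensor((batch, k), dtype="float32")
            )
            R.output(out)
        return out


# Concrete sizes used to build the parameter tensors that will be bound
m, n, k = 4, 6, 8
w0_tvm = tvm.runtime.tensor(np.random.rand(n, m).astype(np.float32))
```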
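The excerpt stops after building `w0_tvm` from the concrete sizes `m, n, k = 4, 6, 8`. Binding such tensors to `main`'s parameters should pin the symbolic dimensions they carry. The plain-Python sketch below illustrates that unification step under stated assumptions. It is not TVM's `BindParams` implementation. The helpers `bind_symbolic_dims` and `substitute` are hypothetical, and the sketch assumes all four weights (`w0`, `b0`, `w1`, `b1`) get bound while `x` stays a free parameter.

```python
from typing import Dict, Tuple, Union

Dim = Union[str, int]  # a symbolic dimension name such as "batch", or a concrete size

# Symbolic parameter shapes, copied from the signature of Before.main
PARAM_SHAPES: Dict[str, Tuple[Dim, ...]] = {
    "x": ("batch", "m"),
    "w0": ("n", "m"),
    "b0": ("n",),
    "w1": ("k", "n"),
    "b1": ("k",),
}
RET_SHAPE: Tuple[Dim, ...] = ("batch", "k")


def bind_symbolic_dims(
    param_shapes: Dict[str, Tuple[Dim, ...]],
    concrete_shapes: Dict[str, Tuple[int, ...]],
) -> Dict[str, int]:
    """Hypothetical helper: map each symbolic dim to the size seen in bound tensors."""
    bindings: Dict[str, int] = {}
    for name, shape in concrete_shapes.items():
        sym = param_shapes[name]
        if len(sym) != len(shape):
            raise ValueError(f"rank mismatch for {name}: {sym} vs {shape}")
        for s, c in zip(sym, shape):
            if isinstance(s, int):
                # Already-concrete dim: the bound tensor must agree exactly
                if s != c:
                    raise ValueError(f"size mismatch for {name}: {s} vs {c}")
            elif bindings.setdefault(s, c) != c:
                # The same symbol must resolve to one size across all parameters
                raise ValueError(f"conflicting sizes for {s!r}: {bindings[s]} vs {c}")
    return bindings


def substitute(shape: Tuple[Dim, ...], bindings: Dict[str, int]) -> Tuple[Dim, ...]:
    """Replace bound symbols with their sizes; unbound symbols stay symbolic."""
    return tuple(bindings.get(d, d) if isinstance(d, str) else d for d in shape)


m, n, k = 4, 6, 8
bound = {"w0": (n, m), "b0": (n,), "w1": (k, n), "b1": (k,)}
bindings = bind_symbolic_dims(PARAM_SHAPES, bound)
print(bindings)                                 # {'n': 6, 'm': 4, 'k': 8}
print(substitute(PARAM_SHAPES["x"], bindings))  # ('batch', 4)
print(substitute(RET_SHAPE, bindings))          # ('batch', 8)
```

Only `batch` survives as a symbol because no bound tensor mentions it, so the remaining input keeps a dynamic leading dimension. Inconsistent sizes for the same symbol are rejected rather than silently overwritten.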

Callers: 1

Calls: 4

Tensor (method, 0.80)
dataflow (method, 0.80)
call_dps_packed (method, 0.80)
output (method, 0.80)
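`call_dps_packed`, one of the calls `main` makes, follows destination-passing style. The caller allocates an output buffer matching `out_sinfo`, and the packed function writes into it. The NumPy sketch below mimics that convention for the two-layer chain. The kernels `"linear0"` and `"linear1"` are opaque in the excerpt. Here they are assumed, from the shapes alone, to compute `x @ w.T + b`. The names `linear_dps`, `call_dps`, and `main_reference` are hypothetical.

```python
import numpy as np


def linear_dps(x, w, b, out):
    """Hypothetical stand-in for "linear0"/"linear1" (assumed: out = x @ w.T + b).

    Writes into the caller-provided buffer instead of returning a new array.
    """
    np.matmul(x, w.T, out=out)
    out += b  # broadcast the bias over the batch dimension


def call_dps(func, args, out_shape, dtype=np.float32):
    """Rough analogue of call_dps_packed: allocate the destination, then call."""
    out = np.empty(out_shape, dtype=dtype)
    func(*args, out)
    return out


def main_reference(x, w0, b0, w1, b1):
    batch, n, k = x.shape[0], w0.shape[0], w1.shape[0]
    lv0 = call_dps(linear_dps, (x, w0, b0), (batch, n))    # (batch, n)
    return call_dps(linear_dps, (lv0, w1, b1), (batch, k))  # (batch, k)


rng = np.random.default_rng(0)
batch, m, n, k = 3, 4, 6, 8
x = rng.random((batch, m), dtype=np.float32)
w0 = rng.random((n, m), dtype=np.float32)
b0 = rng.random((n,), dtype=np.float32)
w1 = rng.random((k, n), dtype=np.float32)
b1 = rng.random((k,), dtype=np.float32)
out = main_reference(x, w0, b0, w1, b1)
print(out.shape)  # (3, 8)
```

Keeping allocation on the caller's side is what lets `out_sinfo` declare the result shape, including symbolic dims such as `batch`, before the opaque kernel runs.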
