(
x: R.Tensor(("batch", "m"), dtype="float32"),
w0: R.Tensor(("n", "m"), dtype="float32"),
b0: R.Tensor(("n",), dtype="float32"),
w1: R.Tensor(("k", "n"), dtype="float32"),
b1: R.Tensor(("k",), dtype="float32"),
)
| 78 | class Before: |
| 79 | @R.function |
| 80 | def main( |
| 81 | x: R.Tensor(("batch", "m"), dtype="float32"), |
| 82 | w0: R.Tensor(("n", "m"), dtype="float32"), |
| 83 | b0: R.Tensor(("n",), dtype="float32"), |
| 84 | w1: R.Tensor(("k", "n"), dtype="float32"), |
| 85 | b1: R.Tensor(("k",), dtype="float32"), |
| 86 | ) -> R.Tensor(("batch", "k"), dtype="float32"): |
| 87 | batch = T.Var("batch", "int64") |
| 88 | k = T.Var("k", "int64") |
| 89 | m = T.Var("m", "int64") |
| 90 | n = T.Var("n", "int64") |
| 91 | with R.dataflow(): |
| 92 | lv0 = R.call_dps_packed( |
| 93 | "linear0", (x, w0, b0), out_sinfo=R.Tensor((batch, n), dtype="float32") |
| 94 | ) |
| 95 | out = R.call_dps_packed( |
| 96 | "linear1", (lv0, w1, b1), out_sinfo=R.Tensor((batch, k), dtype="float32") |
| 97 | ) |
| 98 | R.output(out) |
| 99 | return out |
| 100 | |
| 101 | m, n, k = 4, 6, 8 |
| 102 | w0_tvm = tvm.runtime.tensor(np.random.rand(n, m).astype(np.float32)) |
no test coverage detected