(
x: R.Tensor((1, 4096), dtype="float16"),
w1: R.Tensor((4096, "r1"), dtype="float16"),
w2: R.Tensor((4096, "r2"), dtype="float16"),
)
# Expected IR for ``main`` after ``relax.transform.RunCodegen`` runs on ``Before``.
#
# Both call sites are lowered to ``R.call_dps_packed`` calls that target the
# same external symbol, "fused_relax_matmul_cublas", even though the two
# weights have different symbolic widths ("r1" vs "r2"). Each call keeps its
# own ``out_sinfo``, so the dynamic output shape is expressed per call site.
#
# NOTE(review): the pasted snippet showed no decorator on this class, yet
# ``Expected["main"]`` below requires an IRModule. Confirm this matches the
# decorator used on ``Before`` (e.g. ``@I.ir_module``).
@tvm.script.ir_module
class Expected:
    @R.function
    def main(
        x: R.Tensor((1, 4096), dtype="float16"),
        w1: R.Tensor((4096, "r1"), dtype="float16"),
        w2: R.Tensor((4096, "r2"), dtype="float16"),
    ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
        # Bind the symbolic dims named "r1"/"r2" in the signature so they can
        # be referenced in the out_sinfo annotations below.
        r1 = T.int64()
        r2 = T.int64()
        with R.dataflow():
            # Extern call on (x, w1); its output struct info is (1, r1).
            lv = R.call_dps_packed(
                "fused_relax_matmul_cublas",
                (x, w1),
                out_sinfo=R.Tensor((1, r1), dtype="float16"),
            )
            # Same extern symbol reused for (x, w2); only the output shape
            # differs, (1, r2).
            lv1 = R.call_dps_packed(
                "fused_relax_matmul_cublas",
                (x, w2),
                out_sinfo=R.Tensor((1, r2), dtype="float16"),
            )
            gv: R.Tuple(
                R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
            ) = (lv, lv1)
            R.output(gv)
        return gv


# Only ``main`` is compared; other functions present in ``after`` are not
# checked by this assertion.
after = relax.transform.RunCodegen()(Before)
tvm.ir.assert_structural_equal(after["main"], Expected["main"])
nothing calls this directly
no test coverage detected