Repository: github.com/apache/tvm (branch main)

Method main

tests/python/relax/test_transform_codegen_pass.py:341–363

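```python
# Imports this excerpt relies on.
import tvm
from tvm import relax
from tvm.script import relax as R
from tvm.script import tir as T

# Excerpt from a test. `Before` (the input module) and the IRModule
# decorator applied to `Expected` are defined just above and not shown.
class Expected:
    @R.function
    def main(
        x: R.Tensor((1, 4096), dtype="float16"),
        w1: R.Tensor((4096, "r1"), dtype="float16"),
        w2: R.Tensor((4096, "r2"), dtype="float16"),
    ) -> R.Tuple(R.Tensor((1, "r1"), dtype="float16"), R.Tensor((1, "r2"), dtype="float16")):
        # Symbolic dimensions bound by the shapes of w1 and w2.
        r1 = T.int64()
        r2 = T.int64()
        with R.dataflow():
            # Each offloaded matmul becomes a destination-passing call into the
            # external function; out_sinfo keeps the symbolic output dims.
            lv = R.call_dps_packed(
                "fused_relax_matmul_cublas",
                (x, w1),
                out_sinfo=R.Tensor((1, r1), dtype="float16"),
            )
            lv1 = R.call_dps_packed(
                "fused_relax_matmul_cublas",
                (x, w2),
                out_sinfo=R.Tensor((1, r2), dtype="float16"),
            )
            gv: R.Tuple(
                R.Tensor((1, r1), dtype="float16"), R.Tensor((1, r2), dtype="float16")
            ) = (lv, lv1)
            R.output(gv)
        return gv

# RunCodegen is expected to rewrite Before["main"] into exactly this form.
after = relax.transform.RunCodegen()(Before)
tvm.ir.assert_structural_equal(after["main"], Expected["main"])
```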
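The expected module encodes two ideas: the symbolic dimensions `r1` and `r2` are fixed by the concrete shapes of `w1` and `w2`, and each `R.call_dps_packed` call works in destination-passing style, where the caller allocates an output matching `out_sinfo` and the external function writes its result into that buffer. The following is a minimal pure-Python sketch of both ideas under simplifying assumptions. It is not TVM's runtime: the helpers `bind_shape`, `resolve_shape`, `matmul_into` and `call_dps` are hypothetical, `matmul_into` only stands in for the cuBLAS-backed `fused_relax_matmul_cublas`, and the inner dimension 4096 is shrunk to 4.

```python
# Hypothetical model of symbolic shapes + destination-passing style (DPS).
# NOT TVM's runtime; names below are illustrative only.

def bind_shape(annotation, shape, env):
    """Match an annotated shape like (4096, "r1") against a concrete shape,
    recording each symbolic dim in env and checking it stays consistent."""
    if len(annotation) != len(shape):
        raise ValueError(f"rank mismatch: {annotation} vs {shape}")
    for dim, actual in zip(annotation, shape):
        if isinstance(dim, str):
            if env.setdefault(dim, actual) != actual:
                raise ValueError(f"{dim} already bound to {env[dim]}, got {actual}")
        elif dim != actual:
            raise ValueError(f"expected {dim}, got {actual}")
    return env


def resolve_shape(annotation, env):
    """Turn an out_sinfo-style shape such as (1, "r1") into concrete ints."""
    return tuple(env[d] if isinstance(d, str) else d for d in annotation)


def shape_of(mat):
    return (len(mat), len(mat[0]))


def matmul_into(x, w, out):
    """Stand-in kernel: writes x @ w into the preallocated `out` in place."""
    for i in range(len(x)):
        for j in range(len(w[0])):
            out[i][j] = sum(x[i][k] * w[k][j] for k in range(len(w)))


def call_dps(func, args, out_shape):
    """Caller allocates the destination, then func fills it (output passed last)."""
    rows, cols = out_shape
    out = [[0.0] * cols for _ in range(rows)]
    func(*args, out)
    return out


K = 4  # stands in for 4096
x = [[1.0] * K]
w1 = [[1.0, 2.0, 3.0] for _ in range(K)]  # K x 3, so r1 = 3
w2 = [[0.5, 0.5] for _ in range(K)]       # K x 2, so r2 = 2

env = {}
bind_shape((1, K), shape_of(x), env)
bind_shape((K, "r1"), shape_of(w1), env)
bind_shape((K, "r2"), shape_of(w2), env)

lv = call_dps(matmul_into, (x, w1), resolve_shape((1, "r1"), env))
lv1 = call_dps(matmul_into, (x, w2), resolve_shape((1, "r2"), env))
gv = (lv, lv1)
print(lv, lv1)  # [[4.0, 8.0, 12.0]] [[2.0, 2.0]]
```

Binding each symbolic dimension once and reusing it for every `out_sinfo` is what lets one compiled function serve any `r1`/`r2`; the consistency checks in `bind_shape` loosely mirror the shape matching Relax performs on annotated parameters.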

Callers

nothing calls this directly

Calls 4

Tensor · method · 0.80
dataflow · method · 0.80
call_dps_packed · method · 0.80
output · method · 0.80

Tested by

no test coverage detected