MCPcopy Index your code
hub / github.com/apache/tvm / matmul

Method matmul

tests/python/relax/test_pytorch_integration.py:66–82  ·  view source on GitHub ↗

TIR function for matrix multiplication.

(
        var_A: T.handle,
        var_B: T.handle,
        var_C: T.handle,
    )

Source from the content-addressed store, hash-verified

64
65 @T.prim_func(s_tir=True)
66 def matmul(
67 var_A: T.handle,  # bound to buffer A by T.match_buffer below
68 var_B: T.handle,  # bound to buffer B by T.match_buffer below
69 var_C: T.handle,  # bound to output buffer C by T.match_buffer below
70 ):
71 """Compute C = A @ B, with float32 A of shape (n, 16), B of shape (16, 20) and output C of shape (n, 20); n is a symbolic row count."""
72 n = T.int32()  # symbolic size; A and C both use it, so their leading dims are expected to agree
73 A = T.match_buffer(var_A, (n, 16), "float32")
74 B = T.match_buffer(var_B, (16, 20), "float32")
75 C = T.match_buffer(var_C, (n, 20), "float32")
76
77 for i, j, k in T.grid(n, 20, 16):  # i: rows of C, j: columns of C, k: the shared 16-wide dimension
78 with T.sblock("block"):
79 vi, vj, vk = T.axis.remap("SSR", [i, j, k])  # vi, vj are spatial axes; vk is the reduction axis
80 with T.init():
81 C[vi, vj] = T.float32(0)  # initialise the output element to 0 at the start of the reduction over vk
82 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]  # accumulate one partial product per vk
83
84 @I.pyfunc
85 def my_identity_func(self, x: torch.Tensor) -> torch.Tensor:

Calls 2

remap (method) · 0.80
init (method) · 0.45

Tested by

no test coverage detected