MCPcopy Index your code
hub / github.com/apache/tvm / tir_matmul

Method tir_matmul

tests/python/relax/test_transform_bind_params.py:35–47  ·  view source on GitHub ↗
(x: T.handle, y: T.handle, z: T.handle)

Source from the content-addressed store, hash-verified

33 class InputModule:
34 @T.prim_func(s_tir=True)  # TVMScript TIR PrimFunc; s_tir=True presumably enables the block syntax used below (T.sblock) — confirm
35 def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:  # 16x16 matmul C = A @ B; x, y, z are raw handles bound to A, B, C below; result is written into C, nothing returned
36 T.func_attr({"global_symbol": "tir_matmul"})  # attach the symbol name "tir_matmul" to this function
37 A = T.match_buffer(x, (16, 16))  # left operand; dtype omitted so match_buffer's default applies (float32 assumed, matching T.float32 below — confirm)
38 B = T.match_buffer(y, (16, 16))  # right operand, same 16x16 shape
39 C = T.match_buffer(z, (16, 16))  # output buffer, overwritten by the reduction below
40 for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):  # i and k are each pre-split 4 x 4 (outer, inner); loop order is i0, j, k0, i1, k1
41 with T.sblock("matmul"):  # one block instance per (i, j, k) point of the 16x16x16 iteration space
42 vi = T.axis.S(16, i0 * 4 + i1)  # spatial row index 0..15, reassembled from the split i loops
43 vj = T.axis.S(16, j)  # spatial column index 0..15
44 vk = T.axis.R(16, k0 * 4 + k1)  # reduction index 0..15, reassembled from the split k loops
45 with T.init():  # initialisation step of the reduction for each (vi, vj)
46 C[vi, vj] = T.float32(0)  # zero the accumulator
47 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]  # accumulate so that C[i, j] = sum over k of A[i, k] * B[k, j]
48
49 @R.function
50 def main(x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")) -> R.Tensor(

Callers

nothing calls this directly

Calls 1

initMethod · 0.45

Tested by

no test coverage detected