MCP · copy · Index your code
hub / github.com/apache/tvm / fused

Method fused

tests/python/relax/test_transform_fuse_tir.py:1005–1022  ·  view source on GitHub ↗
(
            X: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float32"),
            Y: T.Buffer((T.int64(2048), T.int64(128)), "float32"),
            rotary: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float32"),
            m: T.int64,
        )

Source from the content-addressed store, hash-verified

1003 class Expected:
# Expected-output fixture from test_transform_fuse_tir.py (class Expected):
# one PrimFunc that first computes T_add = X + X elementwise, then writes
# rotary = Y[m - 1, :] * T_add, reusing the same Y row for all 32 values of
# the third axis.
# Params: X and Y are only read; rotary is the buffer written; m is a scalar
# used solely to pick the row of Y.
1004 @T.prim_func(private=True, s_tir=True)
1005 def fused(
1006 X: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float32"),
1007 Y: T.Buffer((T.int64(2048), T.int64(128)), "float32"),
1008 rotary: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float32"),
1009 m: T.int64,
1010 ):
1011 T.func_attr({"tirx.noalias": True})  # presumably declares the buffer params non-aliasing — confirm in TVM docs
1012 T_add = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)))  # intermediate for X + X; no dtype given (assumed TVMScript default float32 — TODO confirm)
1013 for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):  # stage 1: elementwise over the full 1x1x32x128 shape
1014 with T.sblock("T_add"):
1015 v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])  # all four block axes spatial ("SSSS")
1016 T_add[v_ax0, v_ax1, v_ax2, v_ax3] = (
1017 X[v_ax0, v_ax1, v_ax2, v_ax3] + X[v_ax0, v_ax1, v_ax2, v_ax3]
1018 )
1019 for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)):  # stage 2: consumes T_add over the same 1x1x32x128 shape
1020 with T.sblock("rotary"):
1021 v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
1022 rotary[v0, v1, v2, v3] = Y[m + v1 - T.int64(1), v3] * T_add[v0, v1, v2, v3]  # i1 has extent 1, so v1 == 0 and Y's row is m - 1; assumes 1 <= m <= 2048 (not checked here)
1023
1024 @R.function
1025 def main(

Callers

nothing calls this directly

Calls 2

remap (method) · 0.80
init (method) · 0.45

Tested by

no test coverage detected