| 1003 | class Expected: |
|  | # Expected fused PrimFunc (presumably the reference IR the transform under test is compared against; confirm in the enclosing test). |
|  | # Computes T_add = X + X elementwise, then rotary[0, 0, j, k] = Y[m - 1, k] * T_add[0, 0, j, k]; the single Y row m - 1 is reused for every j along axis 2. |
| 1004 | @T.prim_func(private=True, s_tir=True) |
| 1005 | def fused( |
|  | # X: (1, 1, 32, 128) float32 input, read twice per element by the first loop nest. |
| 1006 | X: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float32"), |
|  | # Y: (2048, 128) float32, read only at row m - 1 by the second loop nest. |
| 1007 | Y: T.Buffer((T.int64(2048), T.int64(128)), "float32"), |
|  | # rotary: (1, 1, 32, 128) float32 buffer written by the second loop nest. |
| 1008 | rotary: T.Buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128)), "float32"), |
|  | # m: scalar int64 used as a row offset into Y; the access Y[m - 1, ...] needs 1 <= m <= 2048 to stay in bounds (not checked here). |
| 1009 | m: T.int64, |
| 1010 | ): |
|  | # "tirx.noalias" function attribute; by its name it declares the buffer arguments non-aliasing (confirm against the TIR docs). |
| 1011 | T.func_attr({"tirx.noalias": True}) |
|  | # Intermediate buffer holding X + X. No dtype is given, so the TVMScript default applies (presumably float32; confirm). |
| 1012 | T_add = T.sblock_alloc_buffer((T.int64(1), T.int64(1), T.int64(32), T.int64(128))) |
|  | # Loop nest 1: all four block axes are spatial ("SSSS"); elementwise T_add = X + X. |
| 1013 | for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)): |
| 1014 | with T.sblock("T_add"): |
| 1015 | v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) |
| 1016 | T_add[v_ax0, v_ax1, v_ax2, v_ax3] = ( |
| 1017 | X[v_ax0, v_ax1, v_ax2, v_ax3] + X[v_ax0, v_ax1, v_ax2, v_ax3] |
| 1018 | ) |
|  | # Loop nest 2: i1 has extent 1, so v1 is always 0 and Y is read at row m - 1 for every (v2, v3). |
| 1019 | for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(1), T.int64(32), T.int64(128)): |
| 1020 | with T.sblock("rotary"): |
| 1021 | v0, v1, v2, v3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) |
| 1022 | rotary[v0, v1, v2, v3] = Y[m + v1 - T.int64(1), v3] * T_add[v0, v1, v2, v3] |
| 1023 | |
| 1024 | @R.function |
| 1025 | def main( |