(
A: T.Buffer((T.int64(4), T.int64(8), T.int64(2048)), "float32"),
T_reshape: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32"),
)
| 1243 | class Module: |
| 1244 | @T.prim_func(private=True, s_tir=True) |
| 1245 | def reshape( |
| 1246 | A: T.Buffer((T.int64(4), T.int64(8), T.int64(2048)), "float32"), |
| 1247 | T_reshape: T.Buffer((T.int64(4), T.int64(8), T.int64(32), T.int64(64)), "float32"), |
| 1248 | ): |
| 1249 | T.func_attr({"op_pattern": 2, "tirx.noalias": True}) |
| 1250 | # with T.sblock("root"): |
| 1251 | for ax0, ax1, ax2, ax3 in T.grid(T.int64(4), T.int64(8), T.int64(32), T.int64(64)): |
| 1252 | with T.sblock("T_reshape"): |
| 1253 | v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3]) |
| 1254 | T.reads( |
| 1255 | A[ |
| 1256 | ( |
| 1257 | ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) |
| 1258 | // T.int64(8) |
| 1259 | + v_ax0 |
| 1260 | ) |
| 1261 | % T.int64(4), |
| 1262 | ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) % T.int64(8), |
| 1263 | (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048), |
| 1264 | ] |
| 1265 | ) |
| 1266 | T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3]) |
| 1267 | T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] = A[ |
| 1268 | ( |
| 1269 | ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) // T.int64(8) |
| 1270 | + v_ax0 |
| 1271 | ) |
| 1272 | % T.int64(4), |
| 1273 | ((v_ax2 * T.int64(64) + v_ax3) // T.int64(2048) + v_ax1) % T.int64(8), |
| 1274 | (v_ax2 * T.int64(64) + v_ax3) % T.int64(2048), |
| 1275 | ] |
| 1276 | |
| 1277 | @R.function(private=True) |
| 1278 | def fused_reshape( |
no test coverage detected