(
A: T.Buffer((T.int64(128), T.int64(64)), "float16"),
w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
compute: T.Buffer((T.int64(128),), "float16"),
)
| 1298 | |
| 1299 | @T.prim_func(s_tir=True) |
| 1300 | def encode( |
| 1301 | A: T.Buffer((T.int64(128), T.int64(64)), "float16"), |
| 1302 | w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"), |
| 1303 | compute: T.Buffer((T.int64(128),), "float16"), |
| 1304 | ): |
| 1305 | T.func_attr({"tirx.noalias": True}) |
| 1306 | # with T.sblock("root"): |
| 1307 | max_abs_value = T.sblock_alloc_buffer((T.int64(128),), "float16") |
| 1308 | scale = T.sblock_alloc_buffer((T.int64(128),)) |
| 1309 | for i, k in T.grid(T.int64(128), T.int64(64)): |
| 1310 | with T.sblock("max_abs_value"): |
| 1311 | v_i, v_k = T.axis.remap("SR", [i, k]) |
| 1312 | T.reads(A[v_i, v_k]) |
| 1313 | T.writes(max_abs_value[v_i]) |
| 1314 | with T.init(): |
| 1315 | max_abs_value[v_i] = T.float16(-65504) |
| 1316 | max_abs_value[v_i] = T.max(max_abs_value[v_i], T.fabs(A[v_i, v_k])) |
| 1317 | for i in range(T.int64(128)): |
| 1318 | with T.sblock("scale"): |
| 1319 | v_i = T.axis.spatial(T.int64(128), i) |
| 1320 | T.reads(max_abs_value[v_i]) |
| 1321 | T.writes(scale[v_i]) |
| 1322 | scale[v_i] = T.max( |
| 1323 | T.Cast("float32", max_abs_value[v_i]), T.float32(0.0001) |
| 1324 | ) * T.float32(0.125) |
| 1325 | for j, i, k in T.grid(T.int64(64), T.int64(64), T.int64(2)): |
| 1326 | with T.sblock("w_gathered"): |
| 1327 | v_j, v_i, v_k = T.axis.remap("SSR", [j, i, k]) |
| 1328 | T.reads(A[v_i * T.int64(2) + v_k, v_j], scale[v_i * T.int64(2) + v_k]) |
| 1329 | T.writes(w_gathered[v_j, v_i]) |
| 1330 | with T.init(): |
| 1331 | w_gathered[v_j, v_i] = T.int8(0) |
| 1332 | w_gathered[v_j, v_i] = T.bitwise_or( |
| 1333 | w_gathered[v_j, v_i], |
| 1334 | T.if_then_else( |
| 1335 | v_i * T.int64(2) + v_k < T.int64(128), |
| 1336 | T.shift_left( |
| 1337 | T.bitwise_and( |
| 1338 | T.Cast( |
| 1339 | "int8", |
| 1340 | T.min( |
| 1341 | T.max( |
| 1342 | T.round( |
| 1343 | T.Cast( |
| 1344 | "float32", A[v_i * T.int64(2) + v_k, v_j] |
| 1345 | ) |
| 1346 | / scale[v_i * T.int64(2) + v_k] |
| 1347 | ), |
| 1348 | T.float32(-8), |
| 1349 | ), |
| 1350 | T.float32(7), |
| 1351 | ), |
| 1352 | ), |
| 1353 | T.int8(15), |
| 1354 | ), |
| 1355 | T.Cast("int8", v_k) * T.int8(4), |
| 1356 | ), |
| 1357 | T.int8(0), |
no test coverage detected