MCPcopy Index your code
hub / github.com/apache/tvm / encode

Method encode

tests/python/relax/test_codegen_cutlass.py:1300–1365  ·  view source on GitHub ↗
(
            A: T.Buffer((T.int64(128), T.int64(64)), "float16"),
            w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
            compute: T.Buffer((T.int64(128),), "float16"),
        )

Source from the content-addressed store, hash-verified

1298
1299 @T.prim_func(s_tir=True)
1300 def encode(
1301 A: T.Buffer((T.int64(128), T.int64(64)), "float16"),
1302 w_gathered: T.Buffer((T.int64(64), T.int64(64)), "int8"),
1303 compute: T.Buffer((T.int64(128),), "float16"),
1304 ):
1305 T.func_attr({"tirx.noalias": True})
1306 # with T.sblock("root"):
1307 max_abs_value = T.sblock_alloc_buffer((T.int64(128),), "float16")
1308 scale = T.sblock_alloc_buffer((T.int64(128),))
1309 for i, k in T.grid(T.int64(128), T.int64(64)):
1310 with T.sblock("max_abs_value"):
1311 v_i, v_k = T.axis.remap("SR", [i, k])
1312 T.reads(A[v_i, v_k])
1313 T.writes(max_abs_value[v_i])
1314 with T.init():
1315 max_abs_value[v_i] = T.float16(-65504)
1316 max_abs_value[v_i] = T.max(max_abs_value[v_i], T.fabs(A[v_i, v_k]))
1317 for i in range(T.int64(128)):
1318 with T.sblock("scale"):
1319 v_i = T.axis.spatial(T.int64(128), i)
1320 T.reads(max_abs_value[v_i])
1321 T.writes(scale[v_i])
1322 scale[v_i] = T.max(
1323 T.Cast("float32", max_abs_value[v_i]), T.float32(0.0001)
1324 ) * T.float32(0.125)
1325 for j, i, k in T.grid(T.int64(64), T.int64(64), T.int64(2)):
1326 with T.sblock("w_gathered"):
1327 v_j, v_i, v_k = T.axis.remap("SSR", [j, i, k])
1328 T.reads(A[v_i * T.int64(2) + v_k, v_j], scale[v_i * T.int64(2) + v_k])
1329 T.writes(w_gathered[v_j, v_i])
1330 with T.init():
1331 w_gathered[v_j, v_i] = T.int8(0)
1332 w_gathered[v_j, v_i] = T.bitwise_or(
1333 w_gathered[v_j, v_i],
1334 T.if_then_else(
1335 v_i * T.int64(2) + v_k < T.int64(128),
1336 T.shift_left(
1337 T.bitwise_and(
1338 T.Cast(
1339 "int8",
1340 T.min(
1341 T.max(
1342 T.round(
1343 T.Cast(
1344 "float32", A[v_i * T.int64(2) + v_k, v_j]
1345 )
1346 / scale[v_i * T.int64(2) + v_k]
1347 ),
1348 T.float32(-8),
1349 ),
1350 T.float32(7),
1351 ),
1352 ),
1353 T.int8(15),
1354 ),
1355 T.Cast("int8", v_k) * T.int8(4),
1356 ),
1357 T.int8(0),

Callers 15

find_shard_index — Function · 0.80
codegen_cuda_printf — Function · 0.80
sendjson — Function · 0.80
_accept_conn — Function · 0.80
_connect_proxy_loop — Function · 0.80
ret_value — Method · 0.80
_pair_up — Method · 0.80
_connect — Function · 0.80
_create_cuda_module — Method · 0.80
_compile_cuda_nvrtc — Function · 0.80
_link_nvshmem_nvrtc — Function · 0.80
post — Function · 0.80

Calls 5

remap — Method · 0.80
max — Method · 0.80
spatial — Method · 0.80
min — Method · 0.80
init — Method · 0.45

Tested by

no test coverage detected