MCPcopy
hub / github.com/apache/tvm / test_fp8_packing

Function test_fp8_packing

tests/python/codegen/test_target_codegen_cuda_fp8.py:98–142  ·  view source on GitHub ↗
(dtype)

Source from the content-addressed store, hash-verified

@pytest.mark.gpu
@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0")
def test_fp8_packing(dtype):
    """Round-trip a vectorized buffer through ``uint32`` via ``T.reinterpret`` on CUDA.

    The generated module first reinterprets every ``{dtype}x4`` element of ``A``
    as a ``uint32`` into ``R``, then reinterprets ``R`` back into ``B``. The
    test passes when ``B`` matches ``A`` after both are cast to float16.

    Parameters
    ----------
    dtype : str
        Scalar element dtype name; it is used both as a numpy dtype and as the
        lane type of the 4-lane vector dtype. The values it takes come from a
        parametrization that is not visible in this block.
    """
    length = 64
    vector_length = 4
    packed_dtype = "uint32"
    native_dtype = f"{dtype}x{vector_length}"

    def _create_mod(native_dtype, packed_dtype, length):
        # The script below is kept verbatim: it is parsed by TVMScript and
        # refers to this helper's arguments by name.
        @I.ir_module(s_tir=True)
        class Module:
            @T.prim_func(s_tir=True)
            def main(
                A: T.Buffer((length,), native_dtype),
                R: T.Buffer((length,), packed_dtype),
                B: T.Buffer((length,), native_dtype),
            ):
                T.func_attr({"tirx.noalias": True})
                # Pack: 2 blocks x 32 threads, one element of A per thread.
                for i_0 in T.thread_binding(2, thread="blockIdx.x"):
                    for i_1 in T.thread_binding(32, thread="threadIdx.x"):
                        with T.sblock("R"):
                            v_i = T.axis.spatial(length, i_0 * 32 + i_1)
                            T.reads(A[v_i])
                            T.writes(R[v_i])
                            R[v_i] = T.reinterpret(packed_dtype, A[v_i])
                # Unpack: same launch shape, reading back what was just packed.
                for i_0 in T.thread_binding(2, thread="blockIdx.x"):
                    for i_1 in T.thread_binding(32, thread="threadIdx.x"):
                        with T.sblock("B"):
                            v_i = T.axis.spatial(length, i_0 * 32 + i_1)
                            T.reads(R[v_i])
                            T.writes(B[v_i])
                            B[v_i] = T.reinterpret(native_dtype, R[v_i])

        return Module

    target = "cuda"
    compiled = tvm.compile(_create_mod(native_dtype, packed_dtype, length), target=target)
    dev = tvm.device(target, 0)

    # One row of `vector_length` scalars per vector element of the device buffer.
    host_input = np.random.uniform(low=0, high=5, size=(length, vector_length)).astype(dtype)

    # Source, packed intermediate, and round-tripped destination, in that order.
    src, packed, dst = (
        tvm.runtime.empty(shape=(length,), dtype=buf_dtype, device=dev)
        for buf_dtype in (native_dtype, packed_dtype, native_dtype)
    )
    src.copyfrom(host_input)
    compiled(src, packed, dst)

    # Both sides are cast to float16 before the comparison.
    src_f16 = src.numpy().astype("float16")
    dst_f16 = dst.numpy().astype("float16")
    tvm.testing.assert_allclose(src_f16, dst_f16)
143
144
145native_dtype, promoted_dtype, numpytype = tvm.testing.parameters(

Callers

nothing calls this directly

Calls 9

_create_mod · Function · 0.85
uniform · Method · 0.80
copyfrom · Method · 0.80
numpy · Method · 0.80
f · Function · 0.50
compile · Method · 0.45
device · Method · 0.45
astype · Method · 0.45
empty · Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…