(dtype)
@pytest.mark.gpu
@pytest.mark.skipif(not env.has_cuda_compute(10), reason="need cuda compute >= 10.0")
def test_fp8_packing(dtype):
    """Round-trip ``<dtype>x4`` vectors through a packed ``uint32`` buffer on CUDA.

    The compiled module has two kernels:

    * block "R" reinterprets each element of ``A`` (``<dtype>x4``) as
      ``uint32`` and stores it in ``R``;
    * block "B" reinterprets each element of ``R`` back to ``<dtype>x4``
      and stores it in ``B``.

    The test passes when ``B`` matches ``A`` after both are cast to float16.

    Parameters
    ----------
    dtype : str
        Scalar element dtype of the vectors, used both as the TVM lane type
        and as the numpy ``astype`` target. It is presumably an 8-bit fp8
        type supplied by a test parameter (not visible here; confirm at the
        parameter definition).
    """
    length = 64
    vector_length = 4
    # Threads per CUDA block for both kernels. The grid size is derived from
    # ``length`` inside ``_create_mod``, so the launch keeps covering the
    # whole buffer if ``length`` changes. Previously the launch was a
    # hard-coded 2 x 32, which only matched length == 64.
    threads_per_block = 32
    assert length % threads_per_block == 0, "length must be a multiple of threads_per_block"
    native_dtype = f"{dtype}x{vector_length}"
    # NOTE(review): "uint32" presumes vector_length * bits(dtype) == 32,
    # i.e. four 8-bit fp8 lanes. Confirm before changing vector_length or dtype.
    packed_dtype = "uint32"

    def _create_mod(native_dtype, packed_dtype, length, threads_per_block):
        # Build the two-kernel module. Every value the TVMScript body
        # references is a local of this function, so the parser can
        # capture it from this frame.
        num_blocks = length // threads_per_block

        @I.ir_module(s_tir=True)
        class Module:
            @T.prim_func(s_tir=True)
            def main(
                A: T.Buffer((length,), native_dtype),
                R: T.Buffer((length,), packed_dtype),
                B: T.Buffer((length,), native_dtype),
            ):
                T.func_attr({"tirx.noalias": True})
                # Kernel 1: pack. Reinterpret each native vector as uint32.
                for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"):
                    for i_1 in T.thread_binding(threads_per_block, thread="threadIdx.x"):
                        with T.sblock("R"):
                            v_i = T.axis.spatial(length, i_0 * threads_per_block + i_1)
                            T.reads(A[v_i])
                            T.writes(R[v_i])
                            R[v_i] = T.reinterpret(packed_dtype, A[v_i])
                # Kernel 2: unpack. Reinterpret each uint32 back to a native vector.
                for i_0 in T.thread_binding(num_blocks, thread="blockIdx.x"):
                    for i_1 in T.thread_binding(threads_per_block, thread="threadIdx.x"):
                        with T.sblock("B"):
                            v_i = T.axis.spatial(length, i_0 * threads_per_block + i_1)
                            T.reads(R[v_i])
                            T.writes(B[v_i])
                            B[v_i] = T.reinterpret(native_dtype, R[v_i])

        return Module

    mod = _create_mod(native_dtype, packed_dtype, length, threads_per_block)
    target = "cuda"
    f = tvm.compile(mod, target=target)
    dev = tvm.device(target, 0)

    # Host data carries the vector lanes as a trailing axis of size
    # vector_length; it is copied into a device buffer of `length` vector elements.
    np_shape = (length, vector_length)
    a_np = np.random.uniform(low=0, high=5, size=np_shape).astype(dtype)
    a = tvm.runtime.empty(shape=(length,), dtype=native_dtype, device=dev)
    r = tvm.runtime.empty(shape=(length,), dtype=packed_dtype, device=dev)
    b = tvm.runtime.empty(shape=(length,), dtype=native_dtype, device=dev)
    a.copyfrom(a_np)
    f(a, r, b)
    # Compare in float16 so the assertion does not rely on numpy support for
    # arithmetic on the fp8 dtype itself.
    tvm.testing.assert_allclose(a.numpy().astype("float16"), b.numpy().astype("float16"))
| 143 | |
| 144 | |
| 145 | native_dtype, promoted_dtype, numpytype = tvm.testing.parameters( |
nothing calls this directly
no test coverage detected
searching dependent graphs…