MCPcopy
hub / github.com/ModelTC/LightLLM / test_fp8_block_gemm

Function test_fp8_block_gemm

test/kernel/deepseekv3_fp8_block_gemm_tuning.py:29–71  ·  view source on GitHub ↗
(
    M: int,
    N: int,
    K: int,
    block_size: int,
    dtype: torch.dtype,
    test_count: int = 20,
    **run_config,
)

Source from the content-addressed store, hash-verified

27
@torch.no_grad()
def test_fp8_block_gemm(
    M: int,
    N: int,
    K: int,
    block_size: int,
    dtype: torch.dtype,
    test_count: int = 20,
    **run_config,
):
    """Time ``test_count`` FP8 block GEMMs replayed from a single CUDA graph.

    Builds ``test_count`` independent input sets on the GPU, records one
    ``w8a8_block_fp8_matmul`` call per set into a CUDA graph, replays the
    graph once untimed and once timed, and reports the timed replay.

    Args:
        M: Row count of the activation ``A`` and of the output ``C``.
        N: Column count of the weight ``B`` and of the output ``C``.
        K: Shared inner dimension of ``A`` (M x K) and ``B`` (K x N).
        block_size: Edge length of the square quantization block; passed to
            the kernel as ``(block_size, block_size)`` and used to size the
            scale tensors via ceil-division.
        dtype: dtype of the output buffer ``C``; also forwarded to
            ``w8a8_block_fp8_matmul``.
        test_count: Number of matmuls recorded into the graph. Must be >= 1.
        **run_config: Collected into a dict and forwarded to the kernel as
            its ``run_config`` keyword argument.

    Returns:
        Wall time in milliseconds of one graph replay, i.e. of all
        ``test_count`` matmuls together (not a per-call average).

    Raises:
        ValueError: If ``test_count`` is smaller than 1.
    """
    if test_count < 1:
        raise ValueError(f"test_count must be >= 1, got {test_count}")

    set_seed()

    # Ceil-division: number of quantization blocks along K and along N.
    k_blocks = (K + block_size - 1) // block_size
    n_blocks = (N + block_size - 1) // block_size
    block_shape = (block_size, block_size)

    input_tuples = []
    for _ in range(test_count):
        A = torch.randn((M, K), dtype=torch.float32).cuda().to(torch.float8_e4m3fn)  # Activation
        B = torch.randn((K, N), dtype=torch.float32).cuda().to(torch.float8_e4m3fn)  # Weight
        # All-ones scales: only the timing matters here, not the numerics.
        Ascale = torch.ones((M, k_blocks)).cuda()
        Bscale = torch.ones((k_blocks, n_blocks)).cuda()
        C = torch.randn((M, N), dtype=dtype).cuda()  # Output buffer
        input_tuples.append((A, B, Ascale, Bscale, C))

    # One eager call before capture.
    # NOTE(review): presumably a warm-up so any lazy kernel compilation happens
    # outside graph capture, run once on the last input set - confirm.
    w8a8_block_fp8_matmul(*input_tuples[-1], block_shape, dtype, run_config=run_config)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for A, B, Ascale, Bscale, C in input_tuples:
            # Pass dtype exactly as the eager call does, so the captured
            # kernels use the requested dtype rather than the callee default.
            w8a8_block_fp8_matmul(
                A,
                B,
                Ascale,
                Bscale,
                C,
                block_shape,
                dtype,
                run_config=run_config,
            )

    # First replay is untimed; only the second one is measured.
    graph.replay()
    torch.cuda.synchronize()

    # perf_counter is monotonic, unlike the adjustable wall clock of time.time.
    start = time.perf_counter()
    graph.replay()
    torch.cuda.synchronize()  # replay is asynchronous; wait before stopping the clock
    cost_time = (time.perf_counter() - start) * 1000
    logger.info(f"fp8 mm {M} {N} {K} block {block_size} cost time: {cost_time} ms")
    return cost_time
72
73
74def worker(

Callers 1

worker (Function) · 0.85

Calls 4

w8a8_block_fp8_matmul (Function) · 0.90
replay (Method) · 0.80
set_seed (Function) · 0.70
cuda (Method) · 0.45

Tested by

no test coverage detected