(
M: int,
N: int,
K: int,
block_size: int,
dtype: torch.dtype,
test_count: int = 20,
**run_config,
)
| 27 | |
@torch.no_grad()
def test_fp8_block_gemm(
    M: int,
    N: int,
    K: int,
    block_size: int,
    dtype: torch.dtype,
    test_count: int = 20,
    **run_config,
):
    """Benchmark ``w8a8_block_fp8_matmul`` for one GEMM shape using a CUDA graph.

    Builds ``test_count`` independent input sets (random FP8 e4m3 ``A``/``B``,
    all-ones block scales, random output buffer ``C``) and runs each once eagerly.
    It then captures all ``test_count`` matmul calls into a single CUDA graph,
    replays the graph once as a warm-up, and times one more replay.

    Args:
        M: Rows of ``A`` and ``C``.
        N: Columns of ``B`` and ``C``.
        K: Shared (reduction) dimension; ``A`` is (M, K), ``B`` is (K, N).
        block_size: Edge of the quantization blocks; passed to the matmul as
            ``(block_size, block_size)`` and used to size the scale tensors.
        dtype: dtype of the output buffer ``C``; also passed to every matmul call.
        test_count: Number of distinct input sets captured in the graph.
        **run_config: Kernel configuration, forwarded to the matmul as
            ``run_config=run_config``.

    Returns:
        Wall-clock milliseconds for ONE graph replay, i.e. the total for all
        ``test_count`` matmuls (not a per-call average).
    """
    set_seed()

    # Number of quantization blocks along K and N (ceil division).
    num_k_blocks = (K + block_size - 1) // block_size
    num_n_blocks = (N + block_size - 1) // block_size
    block_shape = (block_size, block_size)

    input_tuples = []
    for _ in range(test_count):
        A = torch.randn((M, K), dtype=torch.float32).cuda().to(torch.float8_e4m3fn)  # activation
        B = torch.randn((K, N), dtype=torch.float32).cuda().to(torch.float8_e4m3fn)  # weight
        Ascale = torch.ones((M, num_k_blocks)).cuda()
        Bscale = torch.ones((num_k_blocks, num_n_blocks)).cuda()
        C = torch.randn((M, N), dtype=dtype).cuda()  # output buffer
        input_tuples.append((A, B, Ascale, Bscale, C))
        # Eager call before capture so kernels are compiled/loaded outside the graph.
        w8a8_block_fp8_matmul(A, B, Ascale, Bscale, C, block_shape, dtype, run_config=run_config)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for A, B, Ascale, Bscale, C in input_tuples:
            # Pass dtype explicitly, matching the warm-up call, so the captured (timed)
            # kernels use the same output dtype as C instead of the matmul's default.
            w8a8_block_fp8_matmul(A, B, Ascale, Bscale, C, block_shape, dtype, run_config=run_config)

    # Warm-up replay, excluded from the measurement.
    graph.replay()
    torch.cuda.synchronize()

    # perf_counter is monotonic and high-resolution, suited to interval timing.
    start = time.perf_counter()
    graph.replay()
    torch.cuda.synchronize()
    cost_time = (time.perf_counter() - start) * 1000
    logger.info(f"fp8 mm {M} {N} {K} block {block_size} cost time: {cost_time} ms")
    return cost_time
| 72 | |
| 73 | |
| 74 | def worker( |
no test coverage detected