(
m: int,
topk_num: int,
hidden_dim: int,
dtype: torch.dtype,
test_count: int,
**config,
)
| 28 | |
@torch.no_grad()
def test_kernel(
    m: int,
    topk_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    test_count: int,
    **config,
):
    """Time ``moe_sum_reduce`` under one launch configuration.

    Records ``test_count`` consecutive ``moe_sum_reduce`` calls into a single
    CUDA graph, each call on its own cloned input/output pair. The graph is
    replayed once untimed, then timed over one more replay.

    Args:
        m: Size of the first dimension of both the input and output tensors.
        topk_num: Size of the second dimension of the input tensor.
        hidden_dim: Size of the last dimension of the input and output tensors.
        dtype: Element dtype of the generated tensors.
        test_count: Number of kernel launches recorded into the graph.
        **config: Forwarded unchanged to ``moe_sum_reduce`` as ``run_config``.

    Returns:
        Elapsed milliseconds for one replay of the whole graph. This is the
        total for all ``test_count`` launches, not a per-launch average.
    """
    set_seed()

    # Input is (m, topk_num, hidden_dim) and output is (m, hidden_dim).
    # NOTE(review): the division by 10 presumably keeps reduced values small
    # for low-precision dtypes; confirm.
    input_tensor = torch.randn((m, topk_num, hidden_dim), device="cuda", dtype=dtype) / 10
    output_tensor = torch.randn((m, hidden_dim), device="cuda", dtype=dtype)

    # Each recorded launch gets distinct buffers, so the graph does not
    # re-run on the same memory (presumably to avoid cache-hot reuse).
    input_tuples = [
        (input_tensor.clone(), output_tensor.clone()) for _ in range(test_count)
    ]

    # Eager warm-up before capture. Presumably this lets one-time setup
    # (e.g. kernel compilation) happen outside the graph.
    moe_sum_reduce(input_tensor, output_tensor, run_config=config)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for graph_input, graph_output in input_tuples:
            moe_sum_reduce(graph_input, graph_output, run_config=config)

    # Untimed replay, so the measured replay excludes any first-replay overhead.
    graph.replay()

    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-measurement.
    torch.cuda.synchronize()
    start = time.perf_counter()
    graph.replay()
    torch.cuda.synchronize()
    cost_time = (time.perf_counter() - start) * 1000

    logger.info(str(config))
    # Report the dtype actually benchmarked instead of a hard-coded "bf16" label.
    logger.info(f"{dtype} {m} cost time: {cost_time} ms")
    return cost_time
| 69 | |
| 70 | |
| 71 | def worker( |
no test coverage detected