(
expert_num: int,
m: int,
n: int,
k: int,
topk: int,
dtype: torch.dtype,
test_count: int,
use_fp8_w8a8: bool,
is_up: bool,
block_shape,
num_fused_shared_experts: int,
**config,
)
| 52 | |
| 53 | @torch.no_grad() |
| 54 | def test_kernel( |
| 55 | expert_num: int, |
| 56 | m: int, |
| 57 | n: int, |
| 58 | k: int, |
| 59 | topk: int, |
| 60 | dtype: torch.dtype, |
| 61 | test_count: int, |
| 62 | use_fp8_w8a8: bool, |
| 63 | is_up: bool, |
| 64 | block_shape, |
| 65 | num_fused_shared_experts: int, |
| 66 | **config, |
| 67 | ): |
| 68 | set_seed() |
| 69 | input_tuples = [] |
| 70 | |
| 71 | a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 |
| 72 | w1_scale = w2_scale = None |
| 73 | if num_fused_shared_experts > 0: |
| 74 | expert_num += num_fused_shared_experts |
| 75 | |
| 76 | if use_fp8_w8a8: |
| 77 | init_dtype = dtype |
| 78 | w1 = torch.randn(expert_num, 2 * n, k, dtype=init_dtype).cuda() |
| 79 | w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=init_dtype).cuda() |
| 80 | w1 = w1.to(torch.float8_e4m3fn) |
| 81 | w2 = w2.to(torch.float8_e4m3fn) |
| 82 | |
| 83 | if block_shape is None: |
| 84 | w1_scale = torch.randn(expert_num, dtype=torch.float32).cuda() |
| 85 | w2_scale = torch.randn(expert_num, dtype=torch.float32).cuda() |
| 86 | else: |
| 87 | block_n, block_k = block_shape[0], block_shape[1] |
| 88 | n_tiles_w1 = (2 * n + block_n - 1) // block_n |
| 89 | n_tiles_w2 = (k + block_n - 1) // block_n |
| 90 | k_tiles_w1 = (k + block_k - 1) // block_k |
| 91 | k_tiles_w2 = (2 * n // 2 + block_k - 1) // block_k |
| 92 | w1_scale = torch.rand((expert_num, n_tiles_w1, k_tiles_w1), dtype=torch.float32).cuda() |
| 93 | w2_scale = torch.rand((expert_num, n_tiles_w2, k_tiles_w2), dtype=torch.float32).cuda() |
| 94 | else: |
| 95 | w1 = torch.randn(expert_num, 2 * n, k, dtype=dtype).cuda() |
| 96 | w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=dtype).cuda() |
| 97 | |
| 98 | rnd_logics = torch.randn(m, expert_num - num_fused_shared_experts, device="cuda") |
| 99 | topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1) |
| 100 | if num_fused_shared_experts > 0: |
| 101 | # 存在融合共享专家的时候,需要pad 共享专家对应的id 到topk_ids 中 |
| 102 | pad_topk_ids = ( |
| 103 | torch.arange( |
| 104 | start=expert_num - num_fused_shared_experts, end=expert_num, step=1, dtype=topk_ids.dtype, device="cuda" |
| 105 | ) |
| 106 | .view(1, num_fused_shared_experts) |
| 107 | .repeat(topk_ids.shape[0], 1) |
| 108 | ) |
| 109 | topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) |
| 110 | topk_weights = torch.randn((m, topk + num_fused_shared_experts), device="cuda", dtype=dtype) / 10 |
| 111 |
no test coverage detected