hub / github.com/ModelTC/LightLLM / test_kernel

Function test_kernel

test/kernel/fuse_moe_tuning.py:54–226
(
    expert_num: int,
    m: int,
    n: int,
    k: int,
    topk: int,
    dtype: torch.dtype,
    test_count: int,
    use_fp8_w8a8: bool,
    is_up: bool,
    block_shape,
    num_fused_shared_experts: int,
    **config,
)
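The parameters above determine the shapes of the tensors the function allocates. The following is a minimal pure-Python sketch of that shape arithmetic, read off the source listing on this page. The helper names `ceil_div` and `moe_test_shapes` are hypothetical and not part of LightLLM; the sketch only computes shapes and allocates nothing.

```python
def ceil_div(a: int, b: int) -> int:
    """Number of blocks of size b needed to cover a elements."""
    return (a + b - 1) // b


def moe_test_shapes(
    expert_num: int,
    m: int,
    n: int,
    k: int,
    topk: int,
    use_fp8_w8a8: bool = False,
    block_shape=None,
    num_fused_shared_experts: int = 0,
) -> dict:
    """Hypothetical helper: tensor shapes implied by test_kernel's arguments."""
    # Fused shared experts are appended after the routed experts.
    total_experts = expert_num + num_fused_shared_experts
    shapes = {
        "a": (m, k),                     # input activations
        "w1": (total_experts, 2 * n, k),  # up projection weights
        "w2": (total_experts, k, n),      # the source writes n as 2 * n // 2
        "topk_ids": (m, topk + num_fused_shared_experts),
        "topk_weights": (m, topk + num_fused_shared_experts),
    }
    if use_fp8_w8a8:
        if block_shape is None:
            # One scale per expert.
            shapes["w1_scale"] = (total_experts,)
            shapes["w2_scale"] = (total_experts,)
        else:
            # One scale per (block_n x block_k) tile of each weight matrix.
            block_n, block_k = block_shape[0], block_shape[1]
            shapes["w1_scale"] = (total_experts, ceil_div(2 * n, block_n), ceil_div(k, block_k))
            shapes["w2_scale"] = (total_experts, ceil_div(k, block_n), ceil_div(n, block_k))
    return shapes


shapes = moe_test_shapes(
    expert_num=8, m=4, n=256, k=512, topk=2,
    use_fp8_w8a8=True, block_shape=(128, 128), num_fused_shared_experts=1,
)
```

With these illustrative values, `w1` covers 9 experts (8 routed plus 1 shared), and the block-wise scale tensors hold one entry per 128 x 128 tile.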

Source from the content-addressed store, hash-verified

@torch.no_grad()
def test_kernel(
    expert_num: int,
    m: int,
    n: int,
    k: int,
    topk: int,
    dtype: torch.dtype,
    test_count: int,
    use_fp8_w8a8: bool,
    is_up: bool,
    block_shape,
    num_fused_shared_experts: int,
    **config,
):
    set_seed()
    input_tuples = []

    a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
    w1_scale = w2_scale = None
    if num_fused_shared_experts > 0:
        expert_num += num_fused_shared_experts

    if use_fp8_w8a8:
        init_dtype = dtype
        w1 = torch.randn(expert_num, 2 * n, k, dtype=init_dtype).cuda()
        w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=init_dtype).cuda()
        w1 = w1.to(torch.float8_e4m3fn)
        w2 = w2.to(torch.float8_e4m3fn)

        if block_shape is None:
            w1_scale = torch.randn(expert_num, dtype=torch.float32).cuda()
            w2_scale = torch.randn(expert_num, dtype=torch.float32).cuda()
        else:
            block_n, block_k = block_shape[0], block_shape[1]
            n_tiles_w1 = (2 * n + block_n - 1) // block_n
            n_tiles_w2 = (k + block_n - 1) // block_n
            k_tiles_w1 = (k + block_k - 1) // block_k
            k_tiles_w2 = (2 * n // 2 + block_k - 1) // block_k
            w1_scale = torch.rand((expert_num, n_tiles_w1, k_tiles_w1), dtype=torch.float32).cuda()
            w2_scale = torch.rand((expert_num, n_tiles_w2, k_tiles_w2), dtype=torch.float32).cuda()
    else:
        w1 = torch.randn(expert_num, 2 * n, k, dtype=dtype).cuda()
        w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=dtype).cuda()

    rnd_logics = torch.randn(m, expert_num - num_fused_shared_experts, device="cuda")
    topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1)
    if num_fused_shared_experts > 0:
        # When fused shared experts are present, pad topk_ids with the shared experts' ids
        pad_topk_ids = (
            torch.arange(
                start=expert_num - num_fused_shared_experts, end=expert_num, step=1, dtype=topk_ids.dtype, device="cuda"
            )
            .view(1, num_fused_shared_experts)
            .repeat(topk_ids.shape[0], 1)
        )
        topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1)
    topk_weights = torch.randn((m, topk + num_fused_shared_experts), device="cuda", dtype=dtype) / 10
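The shared-expert padding step in the listing can be tried in isolation. The sketch below reproduces it on the CPU, assuming only that PyTorch is installed; `pad_shared_expert_ids` is a hypothetical wrapper name introduced here, and the sizes are illustrative rather than taken from the tuning script.

```python
import torch


def pad_shared_expert_ids(
    topk_ids: torch.Tensor, num_routed_experts: int, num_fused_shared_experts: int
) -> torch.Tensor:
    """Append the shared experts' ids to every row of topk_ids (hypothetical helper)."""
    total_experts = num_routed_experts + num_fused_shared_experts
    pad = (
        # Shared experts occupy the last ids: [num_routed_experts, total_experts).
        torch.arange(num_routed_experts, total_experts, dtype=topk_ids.dtype, device=topk_ids.device)
        .view(1, num_fused_shared_experts)
        .repeat(topk_ids.shape[0], 1)
    )
    return torch.cat([topk_ids, pad], dim=1)


torch.manual_seed(0)
m, num_routed, num_shared, topk = 4, 8, 2, 2

# Random router scores over the routed experts only, as in the listing.
logits = torch.randn(m, num_routed)
_, topk_ids = torch.topk(logits, topk, dim=1)

padded = pad_shared_expert_ids(topk_ids, num_routed, num_shared)
```

Every token row ends with the same shared-expert ids, so each token is always dispatched to the shared experts in addition to its `topk` routed ones. This is why the listing sizes `topk_weights` as `topk + num_fused_shared_experts` columns.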

Callers 1

worker · Function · 0.70

Calls 7

moe_align · Function · 0.90
moe_align1 · Function · 0.90
grouped_matmul · Function · 0.90
empty · Method · 0.80
replay · Method · 0.80
set_seed · Function · 0.70
cuda · Method · 0.45

Tested by

no test coverage detected