MCPcopy Index your code
hub / github.com/ModelTC/LightLLM / test_kernel

Function test_kernel

test/kernel/moe_sum_reduce_tuning_bf16.py:30–68  ·  view source on GitHub ↗
(
    m: int,
    topk_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    test_count: int,
    **config,
)

Source from the content-addressed store, hash-verified

28
@torch.no_grad()
def test_kernel(
    m: int,
    topk_num: int,
    hidden_dim: int,
    dtype: torch.dtype,
    test_count: int,
    **config,
):
    """Time ``moe_sum_reduce`` under one launch config using a CUDA graph.

    Builds an input tensor of shape ``(m, topk_num, hidden_dim)`` and an output
    tensor of shape ``(m, hidden_dim)``. Records ``test_count`` calls to
    ``moe_sum_reduce`` (each on its own cloned buffers) into a single CUDA
    graph, then measures the wall time of one replay of that graph.

    Args:
        m: Size of the first dimension of both tensors.
        topk_num: Size of the middle dimension of the input tensor; presumably
            the number of top-k expert outputs being reduced — TODO confirm
            against ``moe_sum_reduce``.
        hidden_dim: Size of the last dimension of both tensors.
        dtype: Element dtype of the generated tensors.
        test_count: Number of ``moe_sum_reduce`` launches recorded in the graph.
        **config: Kernel configuration, forwarded unchanged as ``run_config``.

    Returns:
        Elapsed time in milliseconds for one replay of the graph, i.e. the total
        over all ``test_count`` launches (not a per-launch average).
    """
    set_seed()

    base_input = torch.randn((m, topk_num, hidden_dim), device="cuda", dtype=dtype) / 10
    base_output = torch.randn((m, hidden_dim), device="cuda", dtype=dtype)

    # One independent (input, output) pair per captured launch, so every launch
    # recorded in the graph operates on distinct buffers.
    input_tuples = [(base_input.clone(), base_output.clone()) for _ in range(test_count)]

    # Warm-up outside of capture; presumably forces any JIT compilation / lazy
    # initialization to happen before it could be recorded into the graph.
    moe_sum_reduce(base_input, base_output, run_config=config)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        for cur_input, cur_output in input_tuples:
            moe_sum_reduce(cur_input, cur_output, run_config=config)

    # Untimed replay, presumably so the measured replay below excludes any
    # first-replay overhead.
    graph.replay()

    # Drain all pending GPU work so the timer brackets exactly one replay.
    torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution; time.time() is a wall
    # clock that can jump if the system time is adjusted mid-measurement.
    start = time.perf_counter()
    graph.replay()
    torch.cuda.synchronize()
    cost_time = (time.perf_counter() - start) * 1000

    logger.info(str(config))
    # NOTE(review): the "bf16" label is hard-coded and is logged regardless of `dtype`.
    logger.info(f"bf16 {m} cost time: {cost_time} ms")
    return cost_time
69
70
71def worker(

Callers 1

workerFunction · 0.70

Calls 3

moe_sum_reduceFunction · 0.90
replayMethod · 0.80
set_seedFunction · 0.70

Tested by

no test coverage detected