MCPcopy
hub / github.com/ModelTC/LightLLM / test_decode_attentions

Function test_decode_attentions

test/kernel/deepseekv2_gqa_decode_tuning.py:29–127  ·  view source on GitHub ↗
(
    q_nope_shape: List[int],
    q_rope_shape: List[int],
    kv_nope_shape: List[int],
    kv_rope_shape: List[int],
    test_seq_len: int,
    dtype: torch.dtype,
    test_count: int = 20,
    **run_config,
)

Source from the content-addressed store, hash-verified

27
28@torch.no_grad()
29def test_decode_attentions(
30 q_nope_shape: List[int],
31 q_rope_shape: List[int],
32 kv_nope_shape: List[int],
33 kv_rope_shape: List[int],
34 test_seq_len: int,
35 dtype: torch.dtype,
36 test_count: int = 20,
37 **run_config,
38):
39 set_seed()
40 tmp_class = type("TestObj", (object,), {})
41 infer_state = tmp_class()
42 infer_state.batch_size = q_nope_shape[0]
43 infer_state.max_len_in_batch = test_seq_len
44 infer_state.req_manager = tmp_class()
45 infer_state.req_manager.req_to_token_indexs = torch.zeros(
46 (infer_state.batch_size, infer_state.max_len_in_batch), dtype=torch.int32, device="cuda"
47 )
48 infer_state.req_manager.req_to_token_indexs.view(-1)[:] = torch.arange(
49 0, infer_state.batch_size * infer_state.max_len_in_batch, step=1, dtype=torch.int32
50 ).cuda()
51 infer_state.b_req_idx = torch.arange(0, infer_state.batch_size, step=1, dtype=torch.int32).cuda()
52 infer_state.b_seq_len = torch.full((infer_state.batch_size,), fill_value=test_seq_len, dtype=torch.int32).cuda()
53
54 input_tuples = []
55 for _ in range(test_count):
56 q_nope = torch.randn(q_nope_shape, device="cuda", dtype=dtype) / 10
57 q_rope = torch.randn(q_rope_shape, device="cuda", dtype=dtype) / 10
58 kv_buffer_shape = [
59 (test_seq_len + 10) * infer_state.batch_size,
60 kv_nope_shape[1],
61 kv_nope_shape[2] + kv_rope_shape[2],
62 ]
63 kv_buffer = torch.randn(kv_buffer_shape, device="cuda", dtype=dtype) / 10
64
65 kv_nope = kv_buffer[:, :, 0 : kv_nope_shape[2]]
66 kv_rope = kv_buffer[:, :, kv_nope_shape[2] :]
67 o_tensor = torch.empty_like(q_nope)
68 input_tuples.append((q_nope, q_rope, kv_buffer, kv_nope, kv_rope, o_tensor))
69
70 tensor_dict = {}
71
72 def inner_alloc_func(shape, dtype=torch.float32, device="cuda"):
73 shape = tuple(shape)
74 if shape not in tensor_dict:
75 ans = torch.empty(shape, dtype=dtype, device=device)
76 tensor_dict[shape] = ans
77 return ans
78 else:
79 return tensor_dict[shape]
80
81 gqa_token_decode_attention_flash_decoding(
82 q_nope,
83 q_rope,
84 kv_nope,
85 kv_rope,
86 infer_state,

Callers 1

worker · Function · 0.70

Calls 4

replay · Method · 0.80
set_seed · Function · 0.70
cuda · Method · 0.45

Tested by

no test coverage detected