(
q_nope_shape: List[int],
q_rope_shape: List[int],
kv_nope_shape: List[int],
kv_rope_shape: List[int],
test_seq_len: int,
dtype: torch.dtype,
test_count: int = 20,
**run_config,
)
| 27 | |
| 28 | @torch.no_grad() |
| 29 | def test_decode_attentions( |
| 30 | q_nope_shape: List[int], |
| 31 | q_rope_shape: List[int], |
| 32 | kv_nope_shape: List[int], |
| 33 | kv_rope_shape: List[int], |
| 34 | test_seq_len: int, |
| 35 | dtype: torch.dtype, |
| 36 | test_count: int = 20, |
| 37 | **run_config, |
| 38 | ): |
| 39 | set_seed() |
| 40 | tmp_class = type("TestObj", (object,), {}) |
| 41 | infer_state = tmp_class() |
| 42 | infer_state.batch_size = q_nope_shape[0] |
| 43 | infer_state.max_len_in_batch = test_seq_len |
| 44 | infer_state.req_manager = tmp_class() |
| 45 | infer_state.req_manager.req_to_token_indexs = torch.zeros( |
| 46 | (infer_state.batch_size, infer_state.max_len_in_batch), dtype=torch.int32, device="cuda" |
| 47 | ) |
| 48 | infer_state.req_manager.req_to_token_indexs.view(-1)[:] = torch.arange( |
| 49 | 0, infer_state.batch_size * infer_state.max_len_in_batch, step=1, dtype=torch.int32 |
| 50 | ).cuda() |
| 51 | infer_state.b_req_idx = torch.arange(0, infer_state.batch_size, step=1, dtype=torch.int32).cuda() |
| 52 | infer_state.b_seq_len = torch.full((infer_state.batch_size,), fill_value=test_seq_len, dtype=torch.int32).cuda() |
| 53 | |
| 54 | input_tuples = [] |
| 55 | for _ in range(test_count): |
| 56 | q_nope = torch.randn(q_nope_shape, device="cuda", dtype=dtype) / 10 |
| 57 | q_rope = torch.randn(q_rope_shape, device="cuda", dtype=dtype) / 10 |
| 58 | kv_buffer_shape = [ |
| 59 | (test_seq_len + 10) * infer_state.batch_size, |
| 60 | kv_nope_shape[1], |
| 61 | kv_nope_shape[2] + kv_rope_shape[2], |
| 62 | ] |
| 63 | kv_buffer = torch.randn(kv_buffer_shape, device="cuda", dtype=dtype) / 10 |
| 64 | |
| 65 | kv_nope = kv_buffer[:, :, 0 : kv_nope_shape[2]] |
| 66 | kv_rope = kv_buffer[:, :, kv_nope_shape[2] :] |
| 67 | o_tensor = torch.empty_like(q_nope) |
| 68 | input_tuples.append((q_nope, q_rope, kv_buffer, kv_nope, kv_rope, o_tensor)) |
| 69 | |
# Scratch tensors handed out by inner_alloc_func, keyed by (shape, dtype, device).
tensor_dict = {}

def inner_alloc_func(shape, dtype=torch.float32, device="cuda"):
    """Return an uninitialized tensor, reusing an earlier allocation when possible.

    Args:
        shape: Iterable of ints giving the requested tensor shape.
        dtype: Requested element type.
        device: Requested device, as a string or ``torch.device``.

    Returns:
        A tensor of the requested shape, dtype and device. Repeated calls
        with the same three arguments return the same tensor object, so its
        contents are whatever the previous user left in it.
    """
    # Keying on shape alone would hand back a tensor of the wrong dtype or
    # device when two requests share a shape but differ in either of those.
    key = (tuple(shape), dtype, str(device))
    cached = tensor_dict.get(key)
    if cached is None:
        cached = torch.empty(key[0], dtype=dtype, device=device)
        tensor_dict[key] = cached
    return cached
| 80 | |
| 81 | gqa_token_decode_attention_flash_decoding( |
| 82 | q_nope, |
| 83 | q_rope, |
| 84 | kv_nope, |
| 85 | kv_rope, |
| 86 | infer_state, |
no test coverage detected