| 50 | |
| 51 | class TestTreeMask(unittest.TestCase): |
def setUp(self):
    """Build the shared fixture: model dims, a zeroed paged KV cache and its block table.

    Creates ``self.cache_k`` / ``self.cache_v`` with shape
    ``(max_block_num, num_kv_head, block_size, head_dim)`` in ``self.dtype``,
    and ``self.block_tables`` (int32, shape ``(bsz, block_num_per_seq)``) in
    which every sequence owns enough blocks to hold ``max_seq_len`` tokens.
    """
    # TODO(liuzichang): If set q_head=32 or bsz=128, some case will fail.
    paddle.seed(0)
    self.max_seq_len = 32768
    self.encoder_max_partition_size = self.max_seq_len
    self.max_partition_size = self.max_seq_len

    self.max_dec_len = 1024
    self.bsz = 64
    self.run_time = 3
    self.warm_up = 1
    self.block_size = 64
    self.head_dim = 128
    self.num_q_head = 20
    self.num_kv_head = 4
    self.use_qknorm = True
    self.dtype = "bfloat16"

    self.rope_3d = False
    self.use_neox_rotary_style = False
    # Per-test state; tearDown resets CURRENT_Q back to [None].
    self.CURRENT_Q = [None]
    self.TOTAL_K = []
    self.TOTAL_V = []

    # Initialize cache and block tables.
    block_num_per_seq = (self.max_seq_len + self.block_size - 1) // self.block_size
    max_block_num = block_num_per_seq * self.bsz
    cache_shape = (
        max_block_num,
        self.num_kv_head,
        self.block_size,
        self.head_dim,
    )

    # Allocate directly in the target dtype. zeros(...).astype(...) would first
    # materialize a float32 tensor (twice the bf16 size; ~4 GiB per cache with
    # the sizes above) only to cast it and throw it away.
    self.cache_k = paddle.zeros(shape=cache_shape, dtype=self.dtype)
    self.cache_v = paddle.zeros(shape=cache_shape, dtype=self.dtype)

    # Each sequence reserves block_num_per_seq blocks and ids are handed out
    # in ascending order, so sequence i owns ids
    # [i * block_num_per_seq, (i + 1) * block_num_per_seq). A single
    # arange + reshape builds that table in one op instead of
    # bsz * block_num_per_seq (32768 here) scalar tensor writes.
    self.block_tables = paddle.arange(max_block_num, dtype="int32").reshape(
        [self.bsz, block_num_per_seq]
    )
| 99 | def tearDown(self): |
| 100 | self.CURRENT_Q = [None] |