(test_config)
| 67 | ], |
| 68 | ) |
# check_cache_manager: builds a KVCacheManager from `test_config`, checks the number and
# shapes of the caches it exposes, then mocks a prefill phase and a decoding phase
# through a block table.
#
# test_config is used as a mutable mapping: it is subscripted, popped from, and expanded
# with ** below. Keys read here: "max_batch_size", "hidden_size", "num_layers",
# "num_attention_heads", "block_size", "max_input_len", "max_output_len".
#
# NOTE(review): this listing carries no indentation, so the loop nesting described in
# the comments below is inferred from the statements themselves - confirm against the
# original file before relying on it.
| 69 | def check_cache_manager(test_config): |
| 70 | disable_existing_loggers() |
| 71 | |
# Precondition checked up front: max_batch_size must exceed 1.
| 72 | assert test_config["max_batch_size"] > 1 |
| 73 | |
# These three keys are removed from test_config so that only the remaining keys are
# forwarded to InferenceConfig(**test_config); the popped values go to LlamaConfig.
# pop() mutates the mapping passed in by the caller, so calling this function a second
# time with the same mapping would fail on the first pop.
| 74 | hidden_size = test_config.pop("hidden_size") |
| 75 | num_layers = test_config.pop("num_layers") |
| 76 | num_attention_heads = test_config.pop("num_attention_heads") |
# Integer division: any remainder of hidden_size / num_attention_heads is dropped.
| 77 | head_size = hidden_size // num_attention_heads |
| 78 | block_size = test_config["block_size"] |
| 79 | max_batch_size = test_config["max_batch_size"] |
| 80 | max_input_length = test_config["max_input_len"] |
| 81 | max_output_length = test_config["max_output_len"] |
| 82 | |
| 83 | inference_config = InferenceConfig(**test_config) |
| 84 | model_config = LlamaConfig( |
| 85 | hidden_size=hidden_size, |
| 86 | num_hidden_layers=num_layers, |
| 87 | num_attention_heads=num_attention_heads, |
| 88 | ) |
| 89 | cache_manager = KVCacheManager(inference_config, model_config) |
| 90 | |
# The manager must report at least one block and hold one _cache_blocks entry per block.
| 91 | num_blocks = cache_manager.total_num_blocks |
| 92 | assert num_blocks > 0 |
| 93 | assert len(cache_manager._cache_blocks) == num_blocks |
| 94 | key_caches = cache_manager._kv_caches[0] # key caches for all the blocks in all the layers |
| 95 | assert len(key_caches) == num_layers |
# Expected shape of the first layer's key cache:
# (num_blocks, num_attention_heads, block_size, head_size).
| 96 | expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size) |
| 97 | assert key_caches[0].shape == expected_kv_shape |
# get_physical_cache(0, 0) is unpacked into a (key, value) pair; both are expected to
# have the cache shape without its leading num_blocks dimension.
# NOTE(review): the meaning of the two 0 arguments is not visible here - presumably
# layer id and block id; confirm against KVCacheManager.
| 98 | k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0) |
| 99 | expected_kv_block_shape = expected_kv_shape[1:] |
| 100 | assert k_cache_block0.shape == expected_kv_block_shape |
| 101 | assert v_cache_block0.shape == expected_kv_block_shape |
| 102 | |
# Block table: max_batch_size rows of max_blocks_per_seq int32 slots, all initialised
# to -1. The asserts below treat a slot value >= 0 as an allocated block.
| 103 | max_blocks_per_seq = cache_manager.get_max_blocks_per_sequence() |
| 104 | block_tables = torch.tensor( |
| 105 | [[-1 for _ in range(max_blocks_per_seq)] for _ in range(test_config["max_batch_size"])], dtype=torch.int32 |
| 106 | ) |
# One random prompt length per request, drawn from [1, max_input_length] (randint is
# inclusive at both ends). NOTE(review): no seed is set in the visible code, so the
# lengths differ from run to run.
| 107 | context_lengths = [random.randint(1, max_input_length) for _ in range(max_batch_size)] |
| 108 | cnt_blocks_used = 0 |
| 109 | # Mock Prefill |
# For each request: hand its block-table row and prompt length to the manager, then
# check that every slot up to the one holding the last token,
# (cur_seq_len - 1) // block_size, is non-negative, and add the row's non-negative
# slots to the running count of used blocks. The call's return value is not captured;
# the results are read back from cur_block_table, so the manager presumably fills the
# row in place (integer indexing of a torch tensor yields a view of block_tables).
| 110 | for req_i in range(max_batch_size): |
| 111 | cur_seq_len = context_lengths[req_i] |
| 112 | cur_block_table = block_tables[req_i] |
| 113 | cache_manager.allocate_context_from_block_table(cur_block_table, cur_seq_len) |
| 114 | last_allocated_idx = (cur_seq_len - 1) // block_size |
| 115 | assert torch.all(cur_block_table[: last_allocated_idx + 1] >= 0) |
| 116 | cnt_blocks_used += torch.sum(cur_block_table >= 0).item() |
# Free-block accounting: available blocks must equal the total minus the slots counted
# as used. NOTE(review): whether this runs once after the loop or once per request
# cannot be told from the flattened listing; the equality holds in either placement.
| 117 | assert cache_manager.num_available_blocks == num_blocks - cnt_blocks_used |
| 118 | |
| 119 | # Mock Decoding |
# For each request: draw a random output length from [1, max_output_length] and call
# allocate_token_from_block_table once per generated token, passing the current
# context length and then advancing it by one. context_length is a local int, so the
# context_lengths list itself is not updated.
| 120 | for req_i in range(max_batch_size): |
| 121 | context_length = context_lengths[req_i] |
| 122 | cur_output_length = random.randint(1, max_output_length) |
| 123 | cur_block_table = block_tables[req_i] |
| 124 | for _ in range(cur_output_length): |
| 125 | cache_manager.allocate_token_from_block_table(cur_block_table, context_length) |
| 126 | context_length += 1 |
no test coverage detected
searching dependent graphs…