MCP · copy
hub / github.com/hpcaitech/ColossalAI / check_cache_manager

Function check_cache_manager

tests/test_infer/test_kvcache_manager.py:69–163  ·  view source on GitHub ↗
(test_config)

Source from the content-addressed store, hash-verified

67 ],
68)
69def check_cache_manager(test_config):
70 disable_existing_loggers()
71
72 assert test_config["max_batch_size"] > 1
73
74 hidden_size = test_config.pop("hidden_size")
75 num_layers = test_config.pop("num_layers")
76 num_attention_heads = test_config.pop("num_attention_heads")
77 head_size = hidden_size // num_attention_heads
78 block_size = test_config["block_size"]
79 max_batch_size = test_config["max_batch_size"]
80 max_input_length = test_config["max_input_len"]
81 max_output_length = test_config["max_output_len"]
82
83 inference_config = InferenceConfig(**test_config)
84 model_config = LlamaConfig(
85 hidden_size=hidden_size,
86 num_hidden_layers=num_layers,
87 num_attention_heads=num_attention_heads,
88 )
89 cache_manager = KVCacheManager(inference_config, model_config)
90
91 num_blocks = cache_manager.total_num_blocks
92 assert num_blocks > 0
93 assert len(cache_manager._cache_blocks) == num_blocks
94 key_caches = cache_manager._kv_caches[0] # key caches for all the blocks in all the layers
95 assert len(key_caches) == num_layers
96 expected_kv_shape = (num_blocks, num_attention_heads, block_size, head_size)
97 assert key_caches[0].shape == expected_kv_shape
98 k_cache_block0, v_cache_block0 = cache_manager.get_physical_cache(0, 0)
99 expected_kv_block_shape = expected_kv_shape[1:]
100 assert k_cache_block0.shape == expected_kv_block_shape
101 assert v_cache_block0.shape == expected_kv_block_shape
102
103 max_blocks_per_seq = cache_manager.get_max_blocks_per_sequence()
104 block_tables = torch.tensor(
105 [[-1 for _ in range(max_blocks_per_seq)] for _ in range(test_config["max_batch_size"])], dtype=torch.int32
106 )
107 context_lengths = [random.randint(1, max_input_length) for _ in range(max_batch_size)]
108 cnt_blocks_used = 0
109 # Mock Prefill
110 for req_i in range(max_batch_size):
111 cur_seq_len = context_lengths[req_i]
112 cur_block_table = block_tables[req_i]
113 cache_manager.allocate_context_from_block_table(cur_block_table, cur_seq_len)
114 last_allocated_idx = (cur_seq_len - 1) // block_size
115 assert torch.all(cur_block_table[: last_allocated_idx + 1] >= 0)
116 cnt_blocks_used += torch.sum(cur_block_table >= 0).item()
117 assert cache_manager.num_available_blocks == num_blocks - cnt_blocks_used
118
119 # Mock Decoding
120 for req_i in range(max_batch_size):
121 context_length = context_lengths[req_i]
122 cur_output_length = random.randint(1, max_output_length)
123 cur_block_table = block_tables[req_i]
124 for _ in range(cur_output_length):
125 cache_manager.allocate_token_from_block_table(cur_block_table, context_length)
126 context_length += 1

Callers 1

run_dist (Function) · 0.85

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…