MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / KeyValueCacheParams

Class KeyValueCacheParams

tensorrt_llm/layers/attention.py:285–337  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

283
284
class KeyValueCacheParams:
    """Plain holder for the key/value-cache arguments given to the constructor.

    Every constructor argument defaults to ``None`` and, with one exception,
    is stored unchanged on the instance under the same name. The exception is
    ``past_key_value_length``: it is accepted in the signature but is not
    stored as an attribute.
    """

    def __init__(self,
                 past_key_value: List[Tensor] = None,
                 host_past_key_value_lengths: Tensor = None,
                 host_max_attention_window_sizes: Tensor = None,
                 host_sink_token_length: Tensor = None,
                 kv_cache_block_offsets: Tensor = None,
                 host_kv_cache_block_offsets: Tensor = None,
                 host_kv_cache_pool_pointers: Tensor = None,
                 host_kv_cache_pool_mapping: Tensor = None,
                 cache_indirection: Tensor = None,
                 past_key_value_length: Tensor = None,
                 cross_kv_cache_block_offsets: Tensor = None,
                 host_cross_kv_cache_block_offsets: Tensor = None,
                 host_cross_kv_cache_pool_pointers: Tensor = None,
                 host_cross_kv_cache_pool_mapping: Tensor = None):
        """Store the given arguments as same-named instance attributes."""
        # The cached entries themselves plus the beam/cache indirection input.
        self.past_key_value = past_key_value
        self.cache_indirection = cache_indirection

        # The three host-side inputs that is_valid() requires, together with
        # cache_indirection, when the attention plugin is in use.
        self.host_past_key_value_lengths = host_past_key_value_lengths
        self.host_max_attention_window_sizes = host_max_attention_window_sizes
        self.host_sink_token_length = host_sink_token_length

        # kv_cache_* block offsets / pool pointers / pool mapping.
        self.kv_cache_block_offsets = kv_cache_block_offsets
        self.host_kv_cache_block_offsets = host_kv_cache_block_offsets
        self.host_kv_cache_pool_pointers = host_kv_cache_pool_pointers
        self.host_kv_cache_pool_mapping = host_kv_cache_pool_mapping

        # cross_kv_cache_* counterparts of the group above.
        self.cross_kv_cache_block_offsets = cross_kv_cache_block_offsets
        self.host_cross_kv_cache_block_offsets = host_cross_kv_cache_block_offsets
        self.host_cross_kv_cache_pool_pointers = host_cross_kv_cache_pool_pointers
        self.host_cross_kv_cache_pool_mapping = host_cross_kv_cache_pool_mapping

        # NOTE: past_key_value_length is deliberately left unassigned; the
        # parameter only remains in the signature.

    def get_first_past_key_value(self):
        """Return element 0 of ``past_key_value``, or ``None`` if it is unset."""
        cached = self.past_key_value
        return None if cached is None else cached[0]

    def fill_none_tensor_list(self, list_size):
        """Default ``past_key_value`` to a tuple of ``list_size`` ``None`` entries.

        Does nothing when ``past_key_value`` has already been set.
        """
        if self.past_key_value is not None:
            return
        self.past_key_value = (None, ) * list_size

    def is_valid(self, gpt_attention_plugin):
        """Report whether the inputs required for the given mode are present.

        When ``gpt_attention_plugin`` is falsy nothing is checked and the
        result is ``True``. Otherwise the result is ``True`` only if
        ``host_past_key_value_lengths``, ``host_max_attention_window_sizes``,
        ``host_sink_token_length`` and ``cache_indirection`` are all set
        (i.e. not ``None``).
        """
        if not gpt_attention_plugin:
            return True

        plugin_inputs = (
            self.host_past_key_value_lengths,
            self.host_max_attention_window_sizes,
            self.host_sink_token_length,
            self.cache_indirection,
        )
        return all(entry is not None for entry in plugin_inputs)
338
339
340class BlockSparseAttnParams:

Callers 13

test_lora_attentionMethod · 0.90
test_attentionMethod · 0.90
forwardMethod · 0.90
prepare_inputsMethod · 0.90
forwardMethod · 0.90
prepare_inputsMethod · 0.90
forwardMethod · 0.90
forwardMethod · 0.90
prepare_inputsMethod · 0.90
forwardMethod · 0.85
prepare_inputsMethod · 0.85
forwardMethod · 0.85

Calls

no outgoing calls

Tested by 2

test_lora_attentionMethod · 0.72
test_attentionMethod · 0.72