| 283 | |
| 284 | |
class KeyValueCacheParams:
    """Container for key/value-cache related inputs.

    Every constructor argument is optional and defaults to ``None``.  Each
    one is stored unchanged as an attribute of the same name, with the single
    exception of ``past_key_value_length`` (see ``__init__``).  The values are
    presumably per-layer cache tensors and related host/device buffers
    consumed by attention layers — TODO confirm against callers.
    """

    def __init__(self,
                 past_key_value: List[Tensor] = None,
                 host_past_key_value_lengths: Tensor = None,
                 host_max_attention_window_sizes: Tensor = None,
                 host_sink_token_length: Tensor = None,
                 kv_cache_block_offsets: Tensor = None,
                 host_kv_cache_block_offsets: Tensor = None,
                 host_kv_cache_pool_pointers: Tensor = None,
                 host_kv_cache_pool_mapping: Tensor = None,
                 cache_indirection: Tensor = None,
                 past_key_value_length: Tensor = None,
                 cross_kv_cache_block_offsets: Tensor = None,
                 host_cross_kv_cache_block_offsets: Tensor = None,
                 host_cross_kv_cache_pool_pointers: Tensor = None,
                 host_cross_kv_cache_pool_mapping: Tensor = None):
        """Store the given cache inputs as attributes.

        Args:
            past_key_value: Sequence of cache entries; element 0 is what
                ``get_first_past_key_value`` returns.
            past_key_value_length: Accepted but NOT stored (see note below).
            All other arguments: stored verbatim under the same name.
        """
        self.past_key_value = past_key_value
        self.host_past_key_value_lengths = host_past_key_value_lengths
        self.host_max_attention_window_sizes = host_max_attention_window_sizes
        self.host_sink_token_length = host_sink_token_length
        self.kv_cache_block_offsets = kv_cache_block_offsets
        self.host_kv_cache_block_offsets = host_kv_cache_block_offsets
        self.host_kv_cache_pool_pointers = host_kv_cache_pool_pointers
        self.host_kv_cache_pool_mapping = host_kv_cache_pool_mapping
        self.cross_kv_cache_block_offsets = cross_kv_cache_block_offsets
        self.host_cross_kv_cache_block_offsets = host_cross_kv_cache_block_offsets
        self.host_cross_kv_cache_pool_pointers = host_cross_kv_cache_pool_pointers
        self.host_cross_kv_cache_pool_mapping = host_cross_kv_cache_pool_mapping
        self.cache_indirection = cache_indirection
        # NOTE: ``past_key_value_length`` is deliberately not stored (its
        # assignment was disabled). It stays in the signature so existing
        # callers that pass it keep working; no ``self.past_key_value_length``
        # attribute exists.

    def get_first_past_key_value(self):
        """Return the first cache entry, or ``None`` if there is none.

        ``None`` is returned both when ``past_key_value`` is unset and when it
        is an empty sequence (previously the latter raised ``IndexError``).
        """
        if self.past_key_value is None:
            return None
        try:
            return self.past_key_value[0]
        except IndexError:
            # Empty sequence, e.g. after ``fill_none_tensor_list(0)``:
            # treat it the same as "no cache" instead of crashing.
            return None

    def fill_none_tensor_list(self, list_size):
        """Replace an unset ``past_key_value`` with ``list_size`` ``None`` slots.

        The placeholder is a tuple.  An already-set ``past_key_value`` is
        left untouched.
        """
        if self.past_key_value is None:
            self.past_key_value = tuple([None] * list_size)

    def is_valid(self, gpt_attention_plugin):
        """Report whether the inputs required for the chosen path are present.

        Args:
            gpt_attention_plugin: Truthy when the GPT attention plugin path is
                used; only then are any inputs required.

        Returns:
            ``True`` if ``gpt_attention_plugin`` is falsy. Otherwise ``True``
            only if ``host_past_key_value_lengths``,
            ``host_max_attention_window_sizes``, ``host_sink_token_length``
            and ``cache_indirection`` are all not ``None``.
        """
        if not gpt_attention_plugin:
            return True
        required_inputs = (
            self.host_past_key_value_lengths,
            self.host_max_attention_window_sizes,
            self.host_sink_token_length,
            self.cache_indirection,
        )
        return all(value is not None for value in required_inputs)
| 338 | |
| 339 | |
| 340 | class BlockSparseAttnParams: |
no outgoing calls