Thread-safe LRU cache with prefix matching for prompt KV caches. This cache stores KV caches keyed by token sequences and supports: - Exact match: Return the cache for the exact token sequence - Shorter prefix match: Return a cache for a prefix of the tokens - Longer prefix match: If a longer sequence is cached and can be trimmed - LRU eviction: When max_size is exceeded, evict least recently used
| 40 | |
| 41 | |
| 42 | class ThreadSafeLRUPromptCache: |
| 43 | """ |
| 44 | Thread-safe LRU cache with prefix matching for prompt KV caches. |
| 45 | |
| 46 | This cache stores KV caches keyed by token sequences and supports: |
| 47 | - Exact match: Return the cache for the exact token sequence |
| 48 | - Shorter prefix match: Return a cache for a prefix of the tokens |
| 49 | - Longer prefix match: If a longer sequence is cached and can be trimmed |
| 50 | - LRU eviction: When max_size is exceeded, evict least recently used |
| 51 | |
| 52 | Thread safety is provided via a threading.Lock that protects all |
| 53 | cache operations. |
| 54 | |
| 55 | Args: |
| 56 | max_size: Maximum number of cache entries (default: 10) |
| 57 | can_trim_fn: Optional function to check if a cache can be trimmed |
| 58 | trim_fn: Optional function to trim a cache |
| 59 | """ |
| 60 | |
| 61 | def __init__( |
| 62 | self, |
| 63 | max_size: int = 10, |
| 64 | can_trim_fn: Optional[Any] = None, |
| 65 | trim_fn: Optional[Any] = None, |
| 66 | ): |
| 67 | self.max_size = max_size |
| 68 | self._cache = {} |
| 69 | self._lru = deque() |
| 70 | self._lock = threading.Lock() |
| 71 | |
| 72 | # Optional trim functions (for longer prefix reuse) |
| 73 | self._can_trim_fn = can_trim_fn |
| 74 | self._trim_fn = trim_fn |
| 75 | |
| 76 | def _search(self, model, tokens: List[int]) -> SearchResult: |
| 77 | """ |
| 78 | Search the cache for a prompt cache. Return exact or close match. |
| 79 | |
| 80 | The cache is organized as a trie where each node is keyed by a token. |
| 81 | This allows efficient prefix matching. |
| 82 | """ |
| 83 | if model not in self._cache: |
| 84 | return SearchResult(model, None, None, None, 0) |
| 85 | |
| 86 | current = self._cache[model] |
| 87 | last_cache_index = -1 |
| 88 | index = 0 |
| 89 | |
| 90 | # Traverse the trie following the token sequence |
| 91 | while index < len(tokens) and tokens[index] in current: |
| 92 | current = current[tokens[index]] |
| 93 | if "cache" in current: |
| 94 | last_cache_index = index |
| 95 | index += 1 |
| 96 | |
| 97 | # Exact match - no need to search for longer or shorter caches |
| 98 | if last_cache_index == len(tokens) - 1: |
| 99 | return SearchResult(model, tuple(tokens), None, None, 0) |