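Search the per-model token trie for a cached prompt. Returns an exact match when the full token sequence is cached; otherwise returns the closest shorter or longer cached sequence. The cache is organized as a trie in which each node is keyed by a token, which allows efficient prefix matching.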
(self, model, tokens: List[int])
| 74 | self._trim_fn = trim_fn |
| 75 | |
def _search(self, model, tokens: List[int]) -> SearchResult:
    """
    Look up ``tokens`` in ``model``'s prompt-cache trie; return an exact or close match.

    The cache is organized as a trie where each node is keyed by a token,
    which allows efficient prefix matching. As read by this method,
    ``self._cache[model]`` is the root node; every node is a dict mapping a
    token to its child node, and a node that holds a stored entry also has
    a ``"cache"`` key.

    Args:
        model: Key selecting the per-model trie in ``self._cache``.
        tokens: Prompt token sequence to look up.

    Returns:
        A ``SearchResult`` constructed positionally from, in order:
        ``model``; ``tuple(tokens)`` on an exact match, else ``None``;
        the deepest cached proper prefix of ``tokens`` (only if at least two
        tokens long), else ``None``; the shortest cached sequence found below
        the last matched node, else ``None``; and the number of leading
        tokens that matched trie edges (``0`` on an exact match or when
        ``model`` has no trie).
    """
    if model not in self._cache:
        return SearchResult(model, None, None, None, 0)

    node = self._cache[model]
    last_cache_index = -1
    index = 0

    # Walk down the trie while tokens match an edge, remembering the
    # deepest node on the path that holds an entry.
    while index < len(tokens) and tokens[index] in node:
        node = node[tokens[index]]
        if "cache" in node:
            last_cache_index = index
        index += 1

    # Exact match: every token was consumed and the final node holds an
    # entry. Checking the node itself (rather than comparing
    # last_cache_index with len(tokens) - 1) keeps an empty ``tokens`` from
    # being reported as an exact hit when the root holds no entry, since
    # both of those values would be -1.
    if index == len(tokens) and "cache" in node:
        return SearchResult(model, tuple(tokens), None, None, 0)

    # Find the shorter cache (a prefix that has a cache).
    # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior.
    # Single-token prefixes are not matched, which allows longer cached
    # sequences to be preferred for trimming. This is acceptable because
    # real prompts with chat templates are always many tokens.
    shorter = None
    if last_cache_index > 0:
        shorter = tuple(tokens[: last_cache_index + 1])

    # Look for the shortest cached sequence below the last matched node,
    # but only when no usable shorter prefix was found.
    # NOTE: when last_cache_index == 0 and index == 1, the start node itself
    # holds an entry, so the "longer" result is the one-token prefix.
    longer = None
    common_prefix = index
    if index > 0 and last_cache_index <= 0:
        best = None
        stack = [(node, [])]
        while stack:
            candidate, extra = stack.pop()
            if "cache" in candidate:
                if best is None or len(extra) < len(best):
                    best = extra
                # Anything stored deeper under this node is longer still.
                continue
            # Prune: children are one token deeper and cannot beat ``best``
            # (ties never replace it), so skip them entirely.
            if best is not None and len(extra) + 1 >= len(best):
                continue
            stack.extend(
                (child, extra + [tok]) for tok, child in candidate.items()
            )
        if best is not None:
            longer = tuple(tokens[:index]) + tuple(best)

    return SearchResult(model, None, shorter, longer, common_prefix)
| 128 | |
| 129 | def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry: |
| 130 | """Get a cache entry by traversing the trie.""" |
no test coverage detected