Generates text based on the given prompt and sampling parameters using MLX. Uses a thread-safe LRU prompt cache for efficient prefix reuse across requests. Args: request: The predict request. context: The gRPC context. Returns: backend_pb2.Reply: The predict result.
(self, request, context)
| 126 | return backend_pb2.Result(message="MLX model loaded successfully", success=True) |
| 127 | |
| 128 | async def Predict(self, request, context): |
| 129 | """ |
| 130 | Generates text based on the given prompt and sampling parameters using MLX. |
| 131 | |
| 132 | Uses thread-safe LRU prompt cache for efficient prefix reuse across requests. |
| 133 | |
| 134 | Args: |
| 135 | request: The predict request. |
| 136 | context: The gRPC context. |
| 137 | |
| 138 | Returns: |
| 139 | backend_pb2.Reply: The predict result. |
| 140 | """ |
| 141 | prompt_cache = None |
| 142 | cache_key = None |
| 143 | |
| 144 | try: |
| 145 | # Prepare the prompt and tokenize for cache key |
| 146 | prompt_text = self._prepare_prompt(request) |
| 147 | cache_key = self._get_tokens_from_prompt(prompt_text) |
| 148 | |
| 149 | # Fetch nearest cache (exact, shorter prefix, or create new) |
| 150 | prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( |
| 151 | self.model_key, cache_key |
| 152 | ) |
| 153 | if prompt_cache is None: |
| 154 | prompt_cache = make_prompt_cache(self.model, self.max_kv_size) |
| 155 | remaining_tokens = cache_key |
| 156 | |
| 157 | # Build generation parameters using request attributes and options |
| 158 | max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request) |
| 159 | |
| 160 | print( |
| 161 | f"Generating text with MLX - max_tokens: {max_tokens}, " |
| 162 | f"cache_hit: {len(remaining_tokens) < len(cache_key)}", |
| 163 | file=sys.stderr, |
| 164 | ) |
| 165 | |
| 166 | # Create sampler and optional logits processors (penalties) |
| 167 | sampler = make_sampler(**sampler_params) |
| 168 | logits_processors = make_logits_processors(**logits_params) if logits_params else None |
| 169 | |
| 170 | # Use stream_generate to collect text + track tokens for cache key |
| 171 | generated_text = [] |
| 172 | last_response = None |
| 173 | for response in stream_generate( |
| 174 | self.model, |
| 175 | self.tokenizer, |
| 176 | prompt=remaining_tokens if remaining_tokens else cache_key, |
| 177 | max_tokens=max_tokens, |
| 178 | sampler=sampler, |
| 179 | logits_processors=logits_processors, |
| 180 | prompt_cache=prompt_cache, |
| 181 | ): |
| 182 | generated_text.append(response.text) |
| 183 | cache_key.append(response.token) |
| 184 | last_response = response |
| 185 | # Early stop on user-provided stop sequences |
nothing calls this directly
no test coverage detected