Generates text based on the given prompt and sampling parameters, and streams the results using MLX. Uses a thread-safe LRU prompt cache for efficient prefix reuse across requests. Args: request: The predict stream request. context: The gRPC context. Yields: backend_pb2.Reply: Streaming predict results.
(self, request, context)
| 289 | return backend_pb2.Result(success=False, message=str(e)) |
| 290 | |
| 291 | async def PredictStream(self, request, context): |
| 292 | """ |
| 293 | Generates text based on the given prompt and sampling parameters, and streams the results using MLX. |
| 294 | |
| 295 | Uses thread-safe LRU prompt cache for efficient prefix reuse across requests. |
| 296 | |
| 297 | Args: |
| 298 | request: The predict stream request. |
| 299 | context: The gRPC context. |
| 300 | |
| 301 | Yields: |
| 302 | backend_pb2.Reply: Streaming predict results. |
| 303 | """ |
| 304 | prompt_cache = None |
| 305 | cache_key = None |
| 306 | |
| 307 | try: |
| 308 | # Prepare the prompt and tokenize for cache key |
| 309 | prompt_text = self._prepare_prompt(request) |
| 310 | cache_key = self._get_tokens_from_prompt(prompt_text) |
| 311 | |
| 312 | # Fetch nearest cache (exact, shorter prefix, or create new) |
| 313 | prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( |
| 314 | self.model_key, cache_key |
| 315 | ) |
| 316 | if prompt_cache is None: |
| 317 | prompt_cache = make_prompt_cache(self.model, self.max_kv_size) |
| 318 | remaining_tokens = cache_key |
| 319 | |
| 320 | # Build generation parameters using request attributes and options |
| 321 | max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params( |
| 322 | request, default_max_tokens=512 |
| 323 | ) |
| 324 | |
| 325 | print( |
| 326 | f"Streaming text with MLX - max_tokens: {max_tokens}, " |
| 327 | f"cache_hit: {len(remaining_tokens) < len(cache_key)}", |
| 328 | file=sys.stderr, |
| 329 | ) |
| 330 | |
| 331 | # Create sampler and optional logits processors (penalties) |
| 332 | sampler = make_sampler(**sampler_params) |
| 333 | logits_processors = make_logits_processors(**logits_params) if logits_params else None |
| 334 | |
| 335 | accumulated = [] |
| 336 | last_response = None |
| 337 | for response in stream_generate( |
| 338 | self.model, |
| 339 | self.tokenizer, |
| 340 | prompt=remaining_tokens if remaining_tokens else cache_key, |
| 341 | max_tokens=max_tokens, |
| 342 | sampler=sampler, |
| 343 | logits_processors=logits_processors, |
| 344 | prompt_cache=prompt_cache, |
| 345 | ): |
| 346 | cache_key.append(response.token) |
| 347 | accumulated.append(response.text) |
| 348 | last_response = response |
nothing calls this directly
no test coverage detected