(self, request, context)
| 276 | return backend_pb2.Reply(message=bytes("", encoding='utf-8')) |
| 277 | |
| 278 | async def PredictStream(self, request, context): |
| 279 | prompt_cache = None |
| 280 | cache_key = None |
| 281 | |
| 282 | try: |
| 283 | import mlx.core as mx |
| 284 | from mlx_lm import stream_generate |
| 285 | from mlx_lm.sample_utils import make_logits_processors, make_sampler |
| 286 | |
| 287 | prompt_text = self._prepare_prompt(request) |
| 288 | tokens = self._get_tokens_from_prompt(prompt_text) |
| 289 | |
| 290 | if self.coordinator: |
| 291 | from coordinator import CMD_GENERATE |
| 292 | self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) |
| 293 | self.coordinator.broadcast_tokens(tokens) |
| 294 | |
| 295 | max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params( |
| 296 | request, default_max_tokens=512 |
| 297 | ) |
| 298 | |
| 299 | if self.coordinator: |
| 300 | gen_params = self.coordinator.broadcast_generation_params( |
| 301 | max_tokens=max_tokens, |
| 302 | temperature=sampler_params.get('temp', 0.6), |
| 303 | top_p=sampler_params.get('top_p', 1.0), |
| 304 | ) |
| 305 | max_tokens = gen_params["max_tokens"] |
| 306 | |
| 307 | sampler = make_sampler(**sampler_params) |
| 308 | logits_processors = make_logits_processors(**logits_params) if logits_params else None |
| 309 | |
| 310 | # Use prompt cache in single-node mode |
| 311 | gen_kwargs = {} |
| 312 | if self.lru_cache is not None: |
| 313 | from mlx_lm.models.cache import make_prompt_cache |
| 314 | cache_key = list(tokens) |
| 315 | prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( |
| 316 | self.model_key, cache_key |
| 317 | ) |
| 318 | if prompt_cache is None: |
| 319 | prompt_cache = make_prompt_cache(self.model, self.max_kv_size) |
| 320 | remaining_tokens = cache_key |
| 321 | gen_kwargs['prompt_cache'] = prompt_cache |
| 322 | tokens = remaining_tokens if remaining_tokens else cache_key |
| 323 | |
| 324 | accumulated = [] |
| 325 | last_response = None |
| 326 | for response in stream_generate( |
| 327 | self.model, |
| 328 | self.tokenizer, |
| 329 | prompt=tokens, |
| 330 | max_tokens=max_tokens, |
| 331 | sampler=sampler, |
| 332 | logits_processors=logits_processors, |
| 333 | **gen_kwargs, |
| 334 | ): |
| 335 | if cache_key is not None: |
nothing calls this directly
no test coverage detected