(self, request, context)
| 186 | return backend_pb2.Result(message="Model loaded successfully", success=True) |
| 187 | |
| 188 | async def Predict(self, request, context): |
| 189 | prompt_cache = None |
| 190 | cache_key = None |
| 191 | |
| 192 | try: |
| 193 | import mlx.core as mx |
| 194 | from mlx_lm import stream_generate |
| 195 | from mlx_lm.sample_utils import make_logits_processors, make_sampler |
| 196 | |
| 197 | prompt_text = self._prepare_prompt(request) |
| 198 | tokens = self._get_tokens_from_prompt(prompt_text) |
| 199 | |
| 200 | if self.coordinator: |
| 201 | from coordinator import CMD_GENERATE |
| 202 | self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) |
| 203 | self.coordinator.broadcast_tokens(tokens) |
| 204 | |
| 205 | max_tokens, sampler_params, logits_params, stop_words = self._build_generation_params(request) |
| 206 | |
| 207 | if self.coordinator: |
| 208 | gen_params = self.coordinator.broadcast_generation_params( |
| 209 | max_tokens=max_tokens, |
| 210 | temperature=sampler_params.get('temp', 0.6), |
| 211 | top_p=sampler_params.get('top_p', 1.0), |
| 212 | ) |
| 213 | max_tokens = gen_params["max_tokens"] |
| 214 | |
| 215 | sampler = make_sampler(**sampler_params) |
| 216 | logits_processors = make_logits_processors(**logits_params) if logits_params else None |
| 217 | |
| 218 | # Use prompt cache in single-node mode |
| 219 | gen_kwargs = {} |
| 220 | if self.lru_cache is not None: |
| 221 | from mlx_lm.models.cache import make_prompt_cache |
| 222 | cache_key = list(tokens) |
| 223 | prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( |
| 224 | self.model_key, cache_key |
| 225 | ) |
| 226 | if prompt_cache is None: |
| 227 | prompt_cache = make_prompt_cache(self.model, self.max_kv_size) |
| 228 | remaining_tokens = cache_key |
| 229 | gen_kwargs['prompt_cache'] = prompt_cache |
| 230 | tokens = remaining_tokens if remaining_tokens else cache_key |
| 231 | |
| 232 | generated = [] |
| 233 | last_response = None |
| 234 | for response in stream_generate( |
| 235 | self.model, |
| 236 | self.tokenizer, |
| 237 | prompt=tokens, |
| 238 | max_tokens=max_tokens, |
| 239 | sampler=sampler, |
| 240 | logits_processors=logits_processors, |
| 241 | **gen_kwargs, |
| 242 | ): |
| 243 | generated.append(response.text) |
| 244 | last_response = response |
| 245 | if cache_key is not None: |
nothing calls this directly
no test coverage detected