(self, params)
| 75 | self.init_heart_beat() |
| 76 | |
| 77 | async def generate_stream(self, params): |
| 78 | self.call_ct += 1 |
| 79 | |
| 80 | context = params.pop("prompt") |
| 81 | request_id = params.pop("request_id") |
| 82 | temperature = float(params.get("temperature", 1.0)) |
| 83 | top_p = float(params.get("top_p", 1.0)) |
| 84 | top_k = params.get("top_k", -1.0) |
| 85 | presence_penalty = float(params.get("presence_penalty", 0.0)) |
| 86 | frequency_penalty = float(params.get("frequency_penalty", 0.0)) |
| 87 | max_new_tokens = params.get("max_new_tokens", 256) |
| 88 | stop_str = params.get("stop", None) |
| 89 | stop_token_ids = params.get("stop_token_ids", None) or [] |
| 90 | if self.tokenizer.eos_token_id is not None: |
| 91 | stop_token_ids.append(self.tokenizer.eos_token_id) |
| 92 | echo = params.get("echo", True) |
| 93 | use_beam_search = params.get("use_beam_search", False) |
| 94 | best_of = params.get("best_of", None) |
| 95 | |
| 96 | # Handle stop_str |
| 97 | stop = set() |
| 98 | if isinstance(stop_str, str) and stop_str != "": |
| 99 | stop.add(stop_str) |
| 100 | elif isinstance(stop_str, list) and stop_str != []: |
| 101 | stop.update(stop_str) |
| 102 | |
| 103 | for tid in stop_token_ids: |
| 104 | if tid is not None: |
| 105 | s = self.tokenizer.decode(tid) |
| 106 | if s != "": |
| 107 | stop.add(s) |
| 108 | |
| 109 | print("Stop patterns: ", stop) |
| 110 | |
| 111 | top_p = max(top_p, 1e-5) |
| 112 | if temperature <= 1e-5: |
| 113 | top_p = 1.0 |
| 114 | |
| 115 | tokens = [] |
| 116 | skip = 0 |
| 117 | |
| 118 | context_mlx = mx.array(self.tokenizer.encode(context)) |
| 119 | |
| 120 | finish_reason = "length" |
| 121 | |
| 122 | iterator = await run_in_threadpool( |
| 123 | generate_step, context_mlx, self.mlx_model, temperature |
| 124 | ) |
| 125 | |
| 126 | for i in range(max_new_tokens): |
| 127 | (token, _) = await run_in_threadpool(next, iterator) |
| 128 | if token == self.mlx_tokenizer.eos_token_id: |
| 129 | finish_reason = "stop" |
| 130 | break |
| 131 | tokens.append(token.item()) |
| 132 | tokens_decoded = self.mlx_tokenizer.decode(tokens) |
| 133 | last_token_decoded = self.mlx_tokenizer.decode([token.item()]) |
| 134 | skip = len(tokens_decoded) |
no test coverage detected