MCP · copy
hub / github.com/lm-sys/FastChat / generate_stream

Method generate_stream

fastchat/serve/mlx_worker.py:77–163  ·  view source on GitHub ↗
Signature: generate_stream(self, params)

Source retrieved from the content-addressed store (hash-verified).

75 self.init_heart_beat()
76
77 async def generate_stream(self, params):
78 self.call_ct += 1
79
80 context = params.pop("prompt")
81 request_id = params.pop("request_id")
82 temperature = float(params.get("temperature", 1.0))
83 top_p = float(params.get("top_p", 1.0))
84 top_k = params.get("top_k", -1.0)
85 presence_penalty = float(params.get("presence_penalty", 0.0))
86 frequency_penalty = float(params.get("frequency_penalty", 0.0))
87 max_new_tokens = params.get("max_new_tokens", 256)
88 stop_str = params.get("stop", None)
89 stop_token_ids = params.get("stop_token_ids", None) or []
90 if self.tokenizer.eos_token_id is not None:
91 stop_token_ids.append(self.tokenizer.eos_token_id)
92 echo = params.get("echo", True)
93 use_beam_search = params.get("use_beam_search", False)
94 best_of = params.get("best_of", None)
95
96 # Handle stop_str
97 stop = set()
98 if isinstance(stop_str, str) and stop_str != "":
99 stop.add(stop_str)
100 elif isinstance(stop_str, list) and stop_str != []:
101 stop.update(stop_str)
102
103 for tid in stop_token_ids:
104 if tid is not None:
105 s = self.tokenizer.decode(tid)
106 if s != "":
107 stop.add(s)
108
109 print("Stop patterns: ", stop)
110
111 top_p = max(top_p, 1e-5)
112 if temperature <= 1e-5:
113 top_p = 1.0
114
115 tokens = []
116 skip = 0
117
118 context_mlx = mx.array(self.tokenizer.encode(context))
119
120 finish_reason = "length"
121
122 iterator = await run_in_threadpool(
123 generate_step, context_mlx, self.mlx_model, temperature
124 )
125
126 for i in range(max_new_tokens):
127 (token, _) = await run_in_threadpool(next, iterator)
128 if token == self.mlx_tokenizer.eos_token_id:
129 finish_reason = "stop"
130 break
131 tokens.append(token.item())
132 tokens_decoded = self.mlx_tokenizer.decode(tokens)
133 last_token_decoded = self.mlx_tokenizer.decode([token.item()])
134 skip = len(tokens_decoded)

Callers 2

generate (Method) · 0.95
api_generate_stream (Function) · 0.45

Calls 1

is_partial_stop (Function) · 0.90

Tested by

no test coverage detected