MCPcopy
hub / github.com/tinygrad/tinygrad / run_model

Method run_model

tinygrad/llm/cli.py:116–141  ·  view source on GitHub ↗
(self, ids:list[int], model_name:str, include_usage=False, max_tokens:int|None=None, temperature:float=0.0)

Source from the content-addressed store, hash-verified

114 if self.path == "/v1/models": self.send_data(json.dumps({"object":"list","data":[{"id":self.server.model_name,"object":"model"}]}).encode())
115 else: self.send_data((pathlib.Path(__file__).parent / "chat.html").read_bytes(), content_type="text/html")
def run_model(self, ids:list[int], model_name:str, include_usage=False, max_tokens:int|None=None, temperature:float=0.0):
  """Generate a completion for the prompt tokens `ids`, yielding "chat.completion.chunk" dicts.

  Yields, in order: an assistant-role chunk with empty content, one content chunk per generated token,
  a chunk with whatever text the stream decoder still holds (only if non-empty), a chunk carrying the
  finish reason, and -- when `include_usage` is true -- a chunk with token usage counts.

  The finish reason is "length" when `max_tokens` is set and that many tokens were produced, otherwise
  "stop". Prompt/prefill/generation statistics are written through `stderr_log` along the way.
  """
  llm, tokenizer = self.server.model, self.server.tok
  cached = llm.get_start_pos(ids)
  stderr_log(f"{self.path} {colored('--', 'BLACK')} "
             f"in:{colored(f'{cached:5d}', 'green')} +{len(ids)-cached:5d} {colored('--', 'BLACK')} ")

  # fields shared by every chunk of this response
  header = {"id":f"chatcmpl-{uuid.uuid4().hex[:24]}", "object":"chat.completion.chunk", "created":int(time.time()), "model":model_name}

  def chunk(delta:dict, finish:str|None=None) -> dict:
    # one single-choice streaming chunk; "choices" stays the first key, header fields follow
    return {"choices": [{"index":0, "delta":delta, "finish_reason":finish}], **header}

  yield chunk({"role":"assistant","content":""})

  generated: list[int] = []
  finish_reason = "stop"
  started = time.perf_counter()
  decode = tokenizer.stream_decoder()
  for token_id in llm.generate(ids, temperature=temperature):
    if not generated:
      # nothing emitted yet: the time elapsed so far is reported as the prefill rate
      first_token_at = time.perf_counter()
      stderr_log(f"prefill:{(len(ids)-cached)/(first_token_at-started):4.0f} tok/s {colored('--', 'BLACK')} ")
    if tokenizer.is_end(token_id):
      break
    generated.append(token_id)
    yield chunk({"content":decode(token_id)})
    if max_tokens is not None and len(generated) >= max_tokens:
      finish_reason = "length"
      break

  # calling the decoder with no argument returns any text it is still holding back
  tail = decode()
  if tail:
    yield chunk({"content":tail})
  yield chunk({}, finish_reason)

  if include_usage:
    n_prompt, n_completion = len(ids), len(generated)
    yield {"choices": [], "usage": {"prompt_tokens": n_prompt, "completion_tokens": n_completion, "total_tokens": n_prompt + n_completion}, **header}

  finished = time.perf_counter()
  # first_token_at is only read when at least two tokens were generated, so it is always bound here
  gen_rate = len(generated)/(finished-first_token_at) if len(generated) > 1 else 0
  stderr_log(f"gen:{gen_rate:4.0f} tok/s {colored('--', 'BLACK')} "
             f"out:{len(generated):5d} {colored('--', 'BLACK')} total:{finished-started:6.2f}s\n")
142
143 def do_POST(self):
144 tok = self.server.tok

Callers 1

do_POST · Method · 0.95

Calls 8

stderr_log · Function · 0.90
colored · Function · 0.90
dec · Function · 0.85
get_start_pos · Method · 0.80
stream_decoder · Method · 0.80
is_end · Method · 0.80
append · Method · 0.80
generate · Method · 0.45

Tested by

no test coverage detected