github.com/tinygrad/tinygrad

Function generate

examples/mamba.py:274–298
(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: float = 1.0, sample: bool = False, top_k: Optional[int] = None)
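`temp` acts as a float divisor: in the sampling branch the logits are divided by `temp` before the softmax, so values below 1.0 sharpen the distribution and values above 1.0 flatten it. It has no effect when `sample=False`, because argmax is unchanged by positive scaling. A plain-Python sketch of that step, independent of tinygrad; `softmax_with_temperature` is a hypothetical helper, not part of the source:

```python
import math
from typing import List

def softmax_with_temperature(logits: List[float], temp: float = 1.0) -> List[float]:
    """Divide logits by temp, then apply a numerically stable softmax."""
    if temp <= 0:
        raise ValueError("temp must be positive")
    scaled = [x / temp for x in logits]
    m = max(scaled)  # subtracting the max keeps math.exp from overflowing
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [2.0, 1.0, 0.0]
base = softmax_with_temperature(logits, 1.0)
sharp = softmax_with_temperature(logits, 0.5)  # more mass on index 0
flat = softmax_with_temperature(logits, 2.0)   # mass spread more evenly
```

With these logits, index 0 gets about 0.67 of the probability mass at `temp=1.0`, about 0.87 at `temp=0.5`, and about 0.51 at `temp=2.0`.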


```python
from typing import Optional
from tinygrad import Tensor  # imports assumed; the file's own import lines are outside this excerpt
from tqdm import tqdm

def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: float = 1.0, sample: bool = False, top_k: Optional[int] = None):
  tks = tokenizer(prompt)["input_ids"]
  # Left-pad with token id 50279 until the prompt is at least 4 tokens long
  while len(tks) < 4:
    tks = [50279] + tks

  # Loading in the prompt tokens: one full forward pass, keep the logits of the last position
  logits = model.forward(Tensor([tks]))[:, -1, :]
  for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
    if sample:
      # Temperature scaling: temp < 1 sharpens, temp > 1 flattens the distribution
      scaled_logits = logits / temp
      if top_k is not None:
        # Keep only the top_k logits; every other position becomes -inf (zero probability)
        topk_values, topk_indices = scaled_logits.topk(top_k)
        filtered_logits = Tensor.full_like(scaled_logits, -float("inf"))
        filtered_logits = filtered_logits.scatter(dim=-1, index=topk_indices, src=topk_values)
        tok_Tens = filtered_logits.softmax().multinomial()
      else:
        tok_Tens = scaled_logits.softmax().multinomial()
    else:
      # Greedy decoding: most likely token, reshaped to (1, 1)
      tok_Tens = logits.argmax(axis=-1).unsqueeze(0)
    tok = tok_Tens.item()
    tks.append(tok)
    # Feed back only the newly generated token via model.forward_jit
    logits = model.forward_jit(tok_Tens)[:, -1, :]

  # Decode token by token and concatenate (padding tokens are decoded too)
  output_completions = ''.join([tokenizer.decode(output) for output in tks])
  return output_completions
```
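The top-k branch builds a tensor filled with `-inf` and scatters the `top_k` largest scaled logits back into their original positions, so the softmax assigns zero probability to every other token before `multinomial` draws one sample. A stdlib-only sketch of the same token-selection logic (greedy argmax when `sample` is false, otherwise temperature scaling, optional top-k masking, then one categorical draw). `top_k_filter`, `softmax`, and `pick_token` are hypothetical helpers, not tinygrad APIs, and `random.choices` stands in for `Tensor.multinomial`:

```python
import math
import random
from typing import List, Optional

def top_k_filter(logits: List[float], k: int) -> List[float]:
    """Keep the k largest logits in place; set every other entry to -inf."""
    if k < 1:
        raise ValueError("k must be at least 1")
    k = min(k, len(logits))
    top_indices = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    filtered = [-math.inf] * len(logits)  # analogous to Tensor.full_like(..., -inf)
    for i in top_indices:                 # analogous to .scatter(dim=-1, index=..., src=...)
        filtered[i] = logits[i]
    return filtered

def softmax(logits: List[float]) -> List[float]:
    m = max(logits)                           # finite as long as one logit is finite
    exps = [math.exp(x - m) for x in logits]  # math.exp(-inf) == 0.0
    total = sum(exps)
    return [e / total for e in exps]

def pick_token(logits: List[float], temp: float = 1.0, sample: bool = False,
               top_k: Optional[int] = None, rng: Optional[random.Random] = None) -> int:
    """Greedy argmax, or temperature scaling + optional top-k + one categorical draw."""
    if not sample:
        return max(range(len(logits)), key=lambda i: logits[i])
    scaled = [x / temp for x in logits]
    if top_k is not None:
        scaled = top_k_filter(scaled, top_k)
    probs = softmax(scaled)
    rng = rng or random.Random()
    # random.choices stands in for Tensor.multinomial (a single sample)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]

rng = random.Random(0)
draws = [pick_token([1.0, 3.0, 2.0, 0.5], sample=True, top_k=2, rng=rng) for _ in range(100)]
```

Every draw lands on index 1 or 2, the positions of the two largest logits, because the masked positions receive probability exactly 0.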

Callers 3

test_mamba_130M · method · 0.90
test_mamba_370M · method · 0.90
mamba.py · file · 0.85

Calls 13

Tensor · class · 0.90
tqdm · class · 0.85
topk · method · 0.80
scatter · method · 0.80
multinomial · method · 0.80
softmax · method · 0.80
unsqueeze · method · 0.80
argmax · method · 0.80
item · method · 0.80
append · method · 0.80
forward · method · 0.45
full_like · method · 0.45

Tested by 2

test_mamba_130M · method · 0.72
test_mamba_370M · method · 0.72
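Because `sample` defaults to `False`, calling `generate` with default arguments takes the argmax path, which is deterministic for fixed weights and a fixed prompt. The bodies of `test_mamba_130M` and `test_mamba_370M` are not shown here. The sketch below only illustrates the shape of the loop (left-pad the prompt, run the model once over the whole prompt, then feed back one token per step) using a hypothetical lookup-table model in place of Mamba; `BigramToyModel`, `greedy_generate`, `step`, and `PAD_ID` are all illustrative names, not part of the source:

```python
from typing import Dict, List

PAD_ID = 0          # hypothetical pad id; the real code prepends 50279
MIN_PROMPT_LEN = 4  # matches the minimum length enforced in generate()

class BigramToyModel:
    """Hypothetical stand-in: next-token logits depend only on the last token."""
    def __init__(self, table: Dict[int, List[float]]):
        self.table = table

    def forward(self, tokens: List[int]) -> List[float]:
        # Prompt pass: consume the whole prompt, return logits for the next position
        return self.table[tokens[-1]]

    def step(self, token: int) -> List[float]:
        # Single-token update, playing the role of model.forward_jit(tok_Tens)
        return self.table[token]

def greedy_generate(model: BigramToyModel, prompt_ids: List[int], n_tokens: int) -> List[int]:
    tks = list(prompt_ids)
    while len(tks) < MIN_PROMPT_LEN:
        tks = [PAD_ID] + tks  # left-pad, as generate() does
    logits = model.forward(tks)
    for _ in range(n_tokens):
        tok = max(range(len(logits)), key=lambda i: logits[i])  # argmax
        tks.append(tok)
        logits = model.step(tok)
    return tks

table = {
    0: [0.0, 1.0, 0.0, 0.0],
    1: [0.0, 0.0, 2.0, 0.0],
    2: [0.0, 0.0, 0.0, 3.0],
    3: [0.0, 4.0, 0.0, 0.0],
}
model = BigramToyModel(table)
out = greedy_generate(model, [1], n_tokens=5)
```

Running the same call twice yields the same token list, which is the property a fixed expected-output test can rely on.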
