(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: bool = 1.0, sample: bool = False, top_k: int = None)
| 272 | |
| 273 | |
def generate(model, tokenizer, prompt: str, n_tokens_to_gen: int = 10, temp: float = 1.0, sample: bool = False, top_k: "int | None" = None) -> str:
    """Generate a continuation of *prompt* and return the decoded text.

    The prompt is tokenized, fed through ``model.forward`` in one shot, and
    then ``n_tokens_to_gen`` tokens are produced one at a time, each fed back
    through ``model.forward_jit``.

    Args:
        model: Object exposing ``forward(tokens)`` and ``forward_jit(token)``,
            both returning logits that are indexed as ``[:, -1, :]``.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` list,
            and exposing ``decode(token_id)``.
        prompt: Text to continue.
        n_tokens_to_gen: Number of tokens to generate.
        temp: Sampling temperature; logits are divided by it. Only used when
            ``sample`` is true.
        sample: If true, draw the next token from the softmax distribution;
            otherwise pick the arg-max token (greedy decoding).
        top_k: If given (and ``sample`` is true), restrict sampling to the
            ``top_k`` highest-scoring tokens.

    Returns:
        The decoded text of every token in the sequence: any filler tokens
        that were prepended, the prompt tokens, and the generated tokens.

    Raises:
        ValueError: If ``sample`` is true and ``temp`` is not positive, or
            ``top_k`` is given but not positive.
    """
    # Validate up front: a zero/negative temperature would otherwise turn the
    # logits into inf/NaN and sample garbage silently.
    if sample:
        if temp <= 0:
            raise ValueError(f"temp must be > 0 when sampling, got {temp!r}")
        if top_k is not None and top_k <= 0:
            raise ValueError(f"top_k must be a positive integer, got {top_k!r}")

    tks = tokenizer(prompt)["input_ids"]

    # The prompt is left-padded to at least 4 tokens before the first forward
    # pass. NOTE(review): 50279 is presumably a tokenizer-specific special
    # token and 4 a minimum context the model needs -- confirm against the
    # model/tokenizer in use.
    min_prompt_len = 4
    filler_token_id = 50279
    n_missing = min_prompt_len - len(tks)
    if n_missing > 0:
        tks = [filler_token_id] * n_missing + tks

    # Loading in the prompt tokens
    logits = model.forward(Tensor([tks]))[:, -1, :]
    for _ in tqdm(range(n_tokens_to_gen), desc="Speed Gen"):
        if sample:
            candidate_logits = logits / temp
            if top_k is not None:
                # Keep only the top-k logits; everything else becomes -inf so
                # it gets zero probability after softmax.
                topk_values, topk_indices = candidate_logits.topk(top_k)
                masked = Tensor.full_like(candidate_logits, -float("inf"))
                candidate_logits = masked.scatter(dim=-1, index=topk_indices, src=topk_values)
            next_tok = candidate_logits.softmax().multinomial()
        else:
            next_tok = logits.argmax(axis=-1).unsqueeze(0)

        tks.append(next_tok.item())
        # The logits from the final iteration are never read, but the call is
        # kept because forward_jit may update state inside the model.
        logits = model.forward_jit(next_tok)[:, -1, :]

    # Tokens are decoded one at a time and concatenated.
    # NOTE(review): with byte-level tokenizers a character split across tokens
    # may decode incorrectly this way; decoding the whole list at once may be
    # preferable -- confirm with the tokenizer in use.
    return "".join(tokenizer.decode(tok_id) for tok_id in tks)
| 299 | |
| 300 | if __name__ == "__main__": |
| 301 | ORIG_PROMPT = "Why is gravity " |
searching dependent graphs…