Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete the sequence max_new_tokens times, feeding the predictions back into the model each time. Most likely you'll want to make sure to be in model.eval() mode of operation for this.
(model, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None)
| 427 | |
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
    """
    Autoregressively extend a batch of token-index sequences.

    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and
    complete the sequence max_new_tokens times, feeding the predictions back into
    the model each time. Most likely you'll want the model to be in model.eval()
    mode for this.

    Args:
        model: object exposing ``get_block_size()`` and callable as
            ``model(idx)`` returning a pair whose first element is the logits,
            indexed here as ``logits[:, -1, :]``.
        idx: conditioning indices, shape (b, t).
        max_new_tokens: number of tokens to append to every sequence.
        temperature: divisor applied to the final-step logits before softmax.
        do_sample: if True, draw the next token with ``torch.multinomial``;
            otherwise take the single most probable token.
        top_k: if not None, only the ``top_k`` highest-scoring tokens remain
            eligible. Values larger than the vocabulary size are clamped to it
            instead of raising inside ``torch.topk``.

    Returns:
        The input ``idx`` with ``max_new_tokens`` extra columns concatenated
        along dim 1.
    """
    block_size = model.get_block_size()
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
        # forward the model to get the logits for every position in the context
        logits, _ = model(idx_cond)
        # keep only the logits at the final step and scale by the temperature
        logits = logits[:, -1, :] / temperature
        # optionally restrict the candidates to the top k options
        if top_k is not None:
            # torch.topk raises if k exceeds the size of the last dimension,
            # so never ask for more entries than the vocabulary holds
            k = min(top_k, logits.size(-1))
            kth_best, _ = torch.topk(logits, k)
            # everything scoring below the k-th best gets zero probability
            logits[logits < kth_best[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # either sample from the distribution or take the most likely element
        if do_sample:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            _, idx_next = torch.topk(probs, k=1, dim=-1)
        # append the chosen index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx
| 458 | |
| 459 | def print_samples(num=10): |
| 460 | """ samples from the model and pretty prints the decoded samples """ |
no test coverage detected