Take a conditioning sequence of indices `idx` (a LongTensor of shape (b, t)) and extend it by `max_new_tokens` tokens, feeding each prediction back into the model as input for the next step. You will usually want to put the model in evaluation mode (`model.eval()`) before calling this.
(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None)
| 281 | |
@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and extend
    it by max_new_tokens tokens, feeding each prediction back into the model as input
    for the next step. Most likely you'll want to be in model.eval() mode for this.

    Args:
        idx: (b, t) tensor of token indices used as the starting context.
        max_new_tokens: number of tokens to append to each sequence.
        temperature: divisor applied to the final-step logits before softmax;
            values below 1.0 sharpen the distribution, values above 1.0 flatten it.
            Expected to be positive (0 divides by zero, negatives invert the ranking).
        do_sample: if True, sample the next token from the (possibly top-k
            filtered) distribution; otherwise take the most likely token.
        top_k: if given, keep only the top_k highest logits at each step before
            the softmax. Values larger than the vocabulary size are clamped to it,
            i.e. they keep every token.

    Returns:
        Tensor of shape (b, t + max_new_tokens): idx with the generated tokens
        appended along dim 1.
    """
    for _ in range(max_new_tokens):
        # the model only takes up to block_size positions of context, so crop the
        # running sequence to its last block_size tokens once it grows longer
        idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
        # forward the model; only the logits are needed, the second output is ignored
        logits, _ = self(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            # torch.topk raises if k exceeds the size of the dimension, so clamp k
            # to the vocabulary size (asking for more than that means "keep all")
            k = min(top_k, logits.size(-1))
            v, _ = torch.topk(logits, k)
            # mask logits strictly below the k-th largest; ties with it survive,
            # so slightly more than k options can remain
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # either sample from the distribution or take the most likely element
        if do_sample:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            _, idx_next = torch.topk(probs, k=1, dim=-1)
        # append the chosen index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx
no outgoing calls