hub / github.com/karpathy/minGPT / generate

Method generate

mingpt/model.py:283–310

Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete the sequence max_new_tokens times, feeding the predictions back into the model each time. Most likely you'll want to make sure to be in model.eval() mode of operation for this.
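The note about model.eval() matters because layers such as dropout behave differently in training mode. The following is a minimal sketch using a small stand-in network (not minGPT itself, and the layer sizes are arbitrary assumptions) to show that eval mode makes repeated forward passes agree:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-in network with dropout; an assumption for illustration, not the GPT class.
net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
x = torch.ones(1, 4)

net.train()
a, b = net(x), net(x)  # dropout is active: the two outputs may differ

net.eval()
c, d = net(x), net(x)  # dropout is disabled: the two outputs are identical
same_in_eval = torch.equal(c, d)
```

With greedy decoding (do_sample=False), leaving a model with dropout in training mode can therefore make otherwise deterministic generation vary from call to call.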

(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None)
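To make the temperature, top_k, and do_sample parameters concrete, here is a small sketch of the per-step logit processing applied to one hand-written row of logits. The helper name next_token_probs and the example values are hypothetical; only the tensor operations mirror the method's source.

```python
import torch
import torch.nn.functional as F

def next_token_probs(logits, temperature=1.0, top_k=None):
    """Hypothetical helper mirroring the per-step logit processing in generate."""
    logits = logits / temperature  # division returns a new tensor; the input is untouched
    if top_k is not None:
        v, _ = torch.topk(logits, top_k)
        # Mask everything below the k-th largest logit in each row.
        logits[logits < v[:, [-1]]] = -float('inf')
    return F.softmax(logits, dim=-1)

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])  # made-up values, batch of one

p_default = next_token_probs(logits)
p_sharp = next_token_probs(logits, temperature=0.5)  # temperature < 1 sharpens the distribution
p_top2 = next_token_probs(logits, top_k=2)           # only the two best tokens keep probability

# do_sample=False path: take the most likely token.
_, greedy = torch.topk(p_top2, k=1, dim=-1)
# do_sample=True path: draw from the distribution; masked tokens cannot be drawn.
sampled = torch.multinomial(p_top2, num_samples=1)
```

Temperatures below 1.0 concentrate probability on the highest logit, while values above 1.0 flatten the distribution. A top_k of 2 assigns exactly zero probability to every token outside the two most likely.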

Source from the content-addressed store, hash-verified

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, top_k)
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # either sample from the distribution or take the most likely element
            if do_sample:
                idx_next = torch.multinomial(probs, num_samples=1)
            else:
                _, idx_next = torch.topk(probs, k=1, dim=-1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
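The loop above can be exercised end to end without a trained model. The sketch below is an illustration under stated assumptions: ToyLM is a hypothetical stand-in for the GPT class (it only imitates the (logits, loss) return value and the block_size attribute), and greedy_generate is a greedy-only simplification that uses argmax on the logits, which selects the same token as topk on the softmax output apart from ties.

```python
import torch
import torch.nn as nn

class ToyLM(nn.Module):
    """Hypothetical stand-in for minGPT's GPT: returns (logits, loss) from forward."""
    def __init__(self, vocab_size=11, block_size=4, n_embd=8):
        super().__init__()
        self.block_size = block_size
        self.emb = nn.Embedding(vocab_size, n_embd)
        self.head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        # Reject contexts longer than block_size, so the crop in the loop is required.
        assert idx.size(1) <= self.block_size, "context longer than block_size"
        return self.head(self.emb(idx)), None

@torch.no_grad()
def greedy_generate(model, idx, max_new_tokens):
    """Greedy-only simplification of the loop above (no temperature, top_k, or sampling)."""
    for _ in range(max_new_tokens):
        # Keep only the last block_size tokens as context.
        idx_cond = idx if idx.size(1) <= model.block_size else idx[:, -model.block_size:]
        logits, _ = model(idx_cond)
        # Most likely next token for each row in the batch, shape (b, 1).
        idx_next = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat((idx, idx_next), dim=1)
    return idx

torch.manual_seed(0)
model = ToyLM().eval()
prompt = torch.randint(0, 11, (2, 3))  # batch of 2 prompts, 3 tokens each
out = greedy_generate(model, prompt, max_new_tokens=6)
```

The result has shape (b, t + max_new_tokens) and begins with the unmodified prompt. The prompt plus new tokens exceeds block_size here, so the cropping branch runs; the full sequence is still returned because cropping only affects what the model sees at each step.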

Callers 3

batch_end_callback · Function · 0.80
eval_split · Function · 0.80
test_gpt2 · Method · 0.80

Calls

no outgoing calls

Tested by 1

test_gpt2 · Method · 0.64