hub / github.com/karpathy/makemore / generate

Function generate

makemore.py:429–457

Takes a conditioning sequence of indices idx (a LongTensor of shape (b, t)) and extends it by max_new_tokens tokens, feeding each prediction back into the model as input for the next step. You will most likely want the model in model.eval() mode when calling this.

(model, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None)
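
The feed-back loop described above can be sketched without PyTorch for a single sequence. This is a minimal illustration under stated assumptions, not the makemore implementation: `generate_greedy` and `toy_next_logits` are hypothetical names, a plain list stands in for the `(b,t)` tensor, and only the greedy (`do_sample=False`) path is shown.

```python
def generate_greedy(next_logits, idx, max_new_tokens, block_size):
    """Extend the token list idx by max_new_tokens tokens, one at a time."""
    idx = list(idx)  # copy so the caller's list is not modified
    for _ in range(max_new_tokens):
        # Crop the context to the last block_size tokens once it grows too long.
        context = idx if len(idx) <= block_size else idx[-block_size:]
        logits = next_logits(context)
        # Greedy choice: index of the largest logit.
        next_token = max(range(len(logits)), key=logits.__getitem__)
        # Feed the prediction back in by appending it to the running sequence.
        idx.append(next_token)
    return idx


def toy_next_logits(context, vocab_size=5):
    """Hypothetical stand-in for a model: favours (last token + 1) mod vocab_size."""
    target = (context[-1] + 1) % vocab_size
    return [1.0 if token == target else 0.0 for token in range(vocab_size)]


print(generate_greedy(toy_next_logits, [0], max_new_tokens=6, block_size=3))
# prints [0, 1, 2, 3, 4, 0, 1]
```

The returned list contains the original conditioning tokens followed by the new ones, which mirrors how generate returns the concatenated idx rather than only the completion.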

Source from the content-addressed store, hash-verified

@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None):
    """
    Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
    the sequence max_new_tokens times, feeding the predictions back into the model each time.
    Most likely you'll want to make sure to be in model.eval() mode of operation for this.
    """
    block_size = model.get_block_size()
    for _ in range(max_new_tokens):
        # if the sequence context is growing too long we must crop it at block_size
        idx_cond = idx if idx.size(1) <= block_size else idx[:, -block_size:]
        # forward the model to get the logits for the index in the sequence
        logits, _ = model(idx_cond)
        # pluck the logits at the final step and scale by desired temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop the logits to only the top k options
        if top_k is not None:
            v, _ = torch.topk(logits, top_k)
            logits[logits < v[:, [-1]]] = -float('Inf')
        # apply softmax to convert logits to (normalized) probabilities
        probs = F.softmax(logits, dim=-1)
        # either sample from the distribution or take the most likely element
        if do_sample:
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            _, idx_next = torch.topk(probs, k=1, dim=-1)
        # append sampled index to the running sequence and continue
        idx = torch.cat((idx, idx_next), dim=1)

    return idx

def print_samples(num=10):
    """ samples from the model and pretty prints the decoded samples """
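
The per-step transformation in generate (temperature scaling, optional top-k masking, softmax, then sampling or taking the most likely index) can be illustrated for one row of logits in plain Python. This is a rough sketch, not the library code: `next_token_probs` and `pick` are hypothetical helper names, and `random.choices` is swapped in for `torch.multinomial`.

```python
import math
import random


def next_token_probs(logits, temperature=1.0, top_k=None):
    """Turn one row of logits into probabilities."""
    scaled = [x / temperature for x in logits]
    if top_k is not None:
        # The k-th largest value is the cut-off; anything strictly below it is masked.
        cutoff = sorted(scaled, reverse=True)[top_k - 1]
        scaled = [x if x >= cutoff else -math.inf for x in scaled]
    # Numerically stable softmax; exp(-inf) evaluates to 0.0.
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def pick(probs, do_sample=False, rng=random):
    """Sample from probs, or take the most likely index when do_sample is False."""
    if do_sample:
        return rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return max(range(len(probs)), key=probs.__getitem__)


logits = [2.0, 1.0, 0.0]
print(next_token_probs(logits))                   # baseline distribution
print(next_token_probs(logits, temperature=0.5))  # sharper: more mass on index 0
print(next_token_probs(logits, top_k=2))          # index 2 gets probability 0.0
```

Because the mask uses a strict comparison against the k-th largest value, logits that tie with it survive, so more than top_k candidates can remain. A temperature below 1.0 concentrates probability on the largest logits, and a temperature above 1.0 flattens the distribution.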

Callers 1

print_samples · Function · 0.85

Calls 1

get_block_size · Method · 0.45

Tested by

no test coverage detected