MCPcopy
hub / github.com/karpathy/makemore / print_samples

Function print_samples

makemore.py:459–485  ·  view source on GitHub ↗

Samples from the model and pretty-prints the decoded samples.

(num=10)

Source from the content-addressed store, hash-verified

457 return idx
458
def print_samples(num=10):
    """Sample `num` sequences from the model and print the decoded results.

    The decoded samples are printed in three groups, in this order: those
    contained in the training set, those contained in the test set, and
    those contained in neither.
    """
    # every sequence is seeded with a single token of index 0 (<START>)
    start_tokens = torch.zeros(num, 1, dtype=torch.long).to(args.device)
    # a top_k setting of -1 is handed to generate() as None
    top_k = None if args.top_k == -1 else args.top_k
    # -1 because we already start with <START> token (index 0)
    num_steps = train_dataset.get_output_length() - 1
    sampled = generate(model, start_tokens, num_steps, top_k=top_k, do_sample=True).to('cpu')

    seen_in_train = []
    seen_in_test = []
    unseen = []
    # drop the leading <START> column, then walk the rows as plain python lists
    for tokens in sampled[:, 1:].tolist():
        # token 0 is the <STOP> token: keep only what comes before the first one
        try:
            tokens = tokens[:tokens.index(0)]
        except ValueError:
            # no <STOP> was sampled, keep the full row
            pass
        word = train_dataset.decode(tokens)
        # file the sample under the first dataset that contains it, if any
        if train_dataset.contains(word):
            bucket = seen_in_train
        elif test_dataset.contains(word):
            bucket = seen_in_test
        else:
            bucket = unseen
        bucket.append(word)

    rule = '-' * 80
    print(rule)
    groups = [(seen_in_train, 'in train'), (seen_in_test, 'in test'), (unseen, 'new')]
    for words, desc in groups:
        print(f"{len(words)} samples that are {desc}:")
        for word in words:
            print(word)
    print(rule)
486
487@torch.inference_mode()
488def evaluate(model, dataset, batch_size=50, max_batches=None):

Callers 1

makemore.py — File · 0.85

Calls 4

generate — Function · 0.85
get_output_length — Method · 0.80
decode — Method · 0.80
contains — Method · 0.80

Tested by

no test coverage detected