samples from the model and pretty prints the decoded samples
(num=10)
| 457 | return idx |
| 458 | |
def print_samples(num=10):
    """
    Sample `num` sequences from the model, decode them, and print them
    grouped by whether the decoded word is in the training set, in the
    test set, or in neither.
    """
    # every sequence is seeded with a single token of index 0 (<START>)
    start_tokens = torch.zeros(num, 1, dtype=torch.long).to(args.device)
    # a configured top_k of -1 is passed on to generate() as None
    top_k = None if args.top_k == -1 else args.top_k
    # the <START> token is already in place, hence one step fewer
    num_steps = train_dataset.get_output_length() - 1
    sampled = generate(model, start_tokens, num_steps, top_k=top_k, do_sample=True).to('cpu')

    # insertion order of the keys is also the order the groups are printed in
    buckets = {'in train': [], 'in test': [], 'new': []}
    # drop the leading <START> column, then walk the rows as python lists
    for tokens in sampled[:, 1:].tolist():
        # token 0 doubles as <STOP>: keep only what comes before the first one
        try:
            stop_at = tokens.index(0)
        except ValueError:
            stop_at = len(tokens)
        word = train_dataset.decode(tokens[:stop_at])
        # train membership takes precedence over test membership
        if train_dataset.contains(word):
            label = 'in train'
        elif test_dataset.contains(word):
            label = 'in test'
        else:
            label = 'new'
        buckets[label].append(word)

    rule = '-' * 80
    print(rule)
    for desc, lst in buckets.items():
        print(f"{len(lst)} samples that are {desc}:")
        for word in lst:
            print(word)
    print(rule)
| 486 | |
| 487 | @torch.inference_mode() |
| 488 | def evaluate(model, dataset, batch_size=50, max_batches=None): |
no test coverage detected