Helper function that samples from Grover for a single step. :param tokens: [batch_size, n_ctx_b] tokens that we will predict from :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict :param news_config: config for the GroverModel :param batch_size: batch size to use
(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False)
| 688 | |
| 689 | |
def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
    """
    Helper function that samples from grover for a single step.

    Runs GroverModel (variable scope 'newslm', variables shared via AUTO_REUSE,
    caching enabled) over `tokens`, takes the logits at the final position of each
    sequence, and draws one next token per batch element using either top-p
    (default) or top-k sampling.

    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict.
        Only forwarded to the top-p sampler; the top-k branch does not receive it.
    :param news_config: config for the GroverModel
    :param batch_size: batch size to use. Must equal the leading dimension of
        `tokens`: it is used to reshape the flattened logits back to
        [batch_size, seq_length, vocab_size].
    :param p_for_topp: top-p threshold, or, when `do_topk` is True, the k for
        top-k sampling (cast to int32, so a float value is truncated).
    :param cache: [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_a,
                   news_config.hidden_size // news_config.num_attention_heads] OR, None
    :param do_topk: if True, use top-k sampling with k = int(p_for_topp);
        otherwise use top-p sampling with p = p_for_topp.
    :return: dict with keys
        'new_tokens': sampled token ids, size [batch_size]
        'new_probs': the value of the sampler's 'probs' output at each sampled
            token, also size [batch_size]
        'new_cache': model.new_kvs, size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
            news_config.num_attention_heads, news_config.hidden_size // news_config.num_attention_heads]
        NOTE(review): the axis order documented for `new_cache` (n_ctx_b before
        num_attention_heads) differs from the one documented for `cache`
        (num_attention_heads before n_ctx_a) — confirm against GroverModel.
    """
    model = GroverModel(
        config=news_config,
        is_training=False,
        input_ids=tokens,
        reuse=tf.compat.v1.AUTO_REUSE,
        scope='newslm',
        chop_off_last_token=False,
        do_cache=True,
        cache=cache,
    )
    # logits_flat is rank 2, [batch_size * seq_length, vocab_size]. Restore the batch
    # axis and keep only the final position, whose logits predict the next token.
    _, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
    next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]

    if do_topk:
        # p_for_topp doubles as k in top-k mode.
        # NOTE(review): ignore_ids is not passed here, so masked ids remain sampleable
        # under top-k — confirm whether _top_k_sample accepts an ignore_ids argument.
        sample_info = _top_k_sample(next_logits, num_samples=1, k=tf.cast(p_for_topp, dtype=tf.int32))
    else:
        sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)

    # num_samples=1, so axis 1 of the sample tensor has size 1; drop it.
    new_tokens = tf.squeeze(sample_info['sample'], 1)
    # Pick each row's probability at its sampled index. `tf.batch_gather` is not part of
    # the TF 2.x top-level API; the compat.v1 alias exists in both TF 1.x and 2.x and
    # matches the tf.compat.v1 usage above.
    new_probs = tf.squeeze(tf.compat.v1.batch_gather(sample_info['probs'], sample_info['sample']), 1)

    return {
        'new_tokens': new_tokens,
        'new_probs': new_probs,
        'new_cache': model.new_kvs,
    }
| 733 | |
| 734 | |
| 735 | def initialize_from_context(initial_context, ignore_ids, news_config, p_for_topp=0.95, do_topk=False): |
no test coverage detected