hub / github.com/Turing-Project/WriteGPT / sample_step

Function sample_step

LanguageNetwork/GPT2/train/modeling.py:686–728

Helper function that samples from Grover for a single step.

:param tokens: [batch_size, n_ctx_b] tokens that we will predict from
:param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
:param news_config: config for the GroverModel
:param batch_size: batch size to use

(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False)


def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
    """
    Helper function that samples from grover for a single step
    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
    :param news_config: config for the GroverModel
    :param batch_size: batch size to use
    :param p_for_topp: top-p threshold; when do_topk is set it is cast to int32 and used as k
    :param cache: [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_a,
                   news_config.hidden_size // news_config.num_attention_heads] OR, None
    :param do_topk: if True, sample with _top_k_sample instead of _top_p_sample
    :return: dict with keys
             new_tokens, size [batch_size]
             new_probs, also size [batch_size]
             new_cache, size [batch_size, news_config.num_hidden_layers, 2, n_ctx_b,
             news_config.num_attention_heads, news_config.hidden_size // news_config.num_attention_heads]
    """
    model = GroverModel(
        config=news_config,
        is_training=False,
        input_ids=tokens,
        reuse=tf.AUTO_REUSE,
        scope='newslm',
        chop_off_last_token=False,
        do_cache=True,
        cache=cache,
    )
    # Extract the FINAL SEQ LENGTH: logits_flat is [batch_size * seq_length, vocab_size],
    # so reshape to [batch_size, seq_length, vocab_size] and keep the last position.
    batch_size_times_seq_length, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
    next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]

    if do_topk:
        # Note: ignore_ids is not passed on the top-k path.
        sample_info = _top_k_sample(next_logits, num_samples=1, k=tf.cast(p_for_topp, dtype=tf.int32))
    else:
        sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)

    # One sample per batch row: drop the num_samples axis.
    new_tokens = tf.squeeze(sample_info['sample'], 1)
    # Look up the probability assigned to each sampled token.
    new_probs = tf.squeeze(tf.batch_gather(sample_info['probs'], sample_info['sample']), 1)

    return {
        'new_tokens': new_tokens,
        'new_probs': new_probs,
        'new_cache': model.new_kvs,
    }
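The default branch of `sample_step` delegates to `_top_p_sample`, whose body is not shown on this page. As a rough illustration of what nucleus (top-p) sampling over a single logit vector can look like, here is a minimal pure-Python sketch. The function name `top_p_sample`, its `ignore_mask` and `rng` parameters, and the exact handling of masked tokens are assumptions made for this example; it is not the repository's implementation and does not use TensorFlow.

```python
import math
import random


def top_p_sample(logits, p=0.95, ignore_mask=None, rng=random):
    """Sample one token id from `logits` using nucleus (top-p) filtering.

    Illustrative only: this is not the repository's `_top_p_sample`.
    A truthy `ignore_mask[i]` means token i must never be sampled.
    Returns (token_id, probability of that token under the masked softmax).
    """
    if ignore_mask is None:
        ignore_mask = [False] * len(logits)

    # Softmax over the allowed tokens; masked tokens get probability 0.
    allowed = [l for l, m in zip(logits, ignore_mask) if not m]
    peak = max(allowed)
    exps = [0.0 if m else math.exp(l - peak) for l, m in zip(logits, ignore_mask)]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the smallest set of highest-probability tokens whose mass reaches p.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    nucleus, mass = [], 0.0
    for i in order:
        nucleus.append(i)
        mass += probs[i]
        if mass >= p:
            break

    # Sample inside the nucleus, proportionally to the original probabilities.
    token = rng.choices(nucleus, weights=[probs[i] for i in nucleus], k=1)[0]
    return token, probs[token]
```

Passing a seeded `random.Random` instance as `rng` makes the draw reproducible, which is convenient when comparing different values of `p`.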

Callers 2

initialize_from_context · Function · 0.70
body · Function · 0.70
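The two callers suggest a common decoding pattern: one call over the full context with no cache, followed by repeated calls that feed back only the newest token together with the accumulated cache. Their bodies are not shown here, so the sketch below is an assumption about that pattern, not a description of `initialize_from_context` or `body`. `toy_step` is a hypothetical stand-in that returns a dict with the same keys as `sample_step`, and, following the docstring shapes (`n_ctx_b` for `new_cache`), it returns cache entries only for the tokens it was just given.

```python
def toy_step(tokens, cache=None):
    """Hypothetical stand-in for one sampling step (not the real model).

    The next "token" depends on everything seen so far: the cached past
    plus the tokens passed in. Only the new tokens are returned as cache.
    """
    past = cache or []
    new_token = (sum(past) + sum(tokens) + 1) % 50  # deterministic fake sample
    return {'new_tokens': new_token, 'new_probs': 1.0, 'new_cache': list(tokens)}


def generate(context, num_steps):
    """Generate `num_steps` (>= 1) tokens after `context` using a growing cache."""
    # First call: the whole context, no cache yet.
    out = toy_step(context)
    tokens = list(context) + [out['new_tokens']]
    cache = out['new_cache']

    for _ in range(num_steps - 1):
        # Later calls: only the newest token, with the cache carrying the rest.
        out = toy_step(tokens[-1:], cache=cache)
        tokens.append(out['new_tokens'])
        cache = cache + out['new_cache']  # append the new entries to the cache
    return tokens
```

With this toy step, the cached loop produces the same next token as reprocessing the full sequence from scratch, which is the property a key/value cache is meant to preserve.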

Calls 4

get_shape_list · Function · 0.90
GroverModel · Class · 0.70
_top_k_sample · Function · 0.70
_top_p_sample · Function · 0.70
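Of the callees listed above, `_top_k_sample` is reached only when `do_topk=True`, in which case `p_for_topp` is cast to an integer and used as `k`. Its body is not shown on this page, so the following is only a minimal pure-Python sketch of top-k sampling over one logit vector. The name `top_k_sample`, the `rng` parameter, and the clamping of `k` to at least 1 are choices made for this example, not the repository's behaviour.

```python
import math
import random


def top_k_sample(logits, k, rng=random):
    """Sample one token id from the k highest-scoring logits.

    Illustrative only: this is not the repository's `_top_k_sample`.
    Returns (token_id, probability of that token among the top k).
    """
    # Assumption for this sketch: clamp k into [1, vocabulary size].
    k = max(1, min(int(k), len(logits)))

    # Indices of the k largest logits, best first.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]

    # Softmax restricted to those k candidates.
    peak = logits[top[0]]
    weights = [math.exp(logits[i] - peak) for i in top]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Draw a position within the top-k list, then map it back to a token id.
    choice = rng.choices(range(k), weights=probs, k=1)[0]
    return top[choice], probs[choice]
```

Because an integer cast truncates, a fractional threshold such as the default `0.95` would become `0`; callers that enable top-k sampling presumably pass a whole-number threshold instead.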

Tested by

no test coverage detected