github.com/Turing-Project/WriteGPT

Function sample_step

LanguageNetwork/GPT2/scripts/modeling.py, lines 690–732

Helper function that samples a single step from Grover. Given a batch of token prefixes (and, optionally, a key/value cache for earlier positions), it runs GroverModel once, samples the next token for each sequence with top-p or top-k sampling, and returns the sampled tokens, their probabilities, and the key/value cache for the newly processed positions.
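The `ignore_ids` argument is documented only as an `[n_vocab]` mask of tokens that should not be predicted. One common way to apply such a mask is to subtract a very large constant from the masked logits before the softmax, so those tokens end up with essentially zero probability. The NumPy sketch below stands in for the TensorFlow code; `masked_softmax` and `penalty` are hypothetical names, and whether `_top_p_sample` uses exactly this penalty is an assumption.

```python
import numpy as np

def masked_softmax(logits, ignore_ids, penalty=1e10):
    """Softmax over the vocab axis that gives ~zero probability to ignored tokens.

    logits:     [batch_size, n_vocab] float array
    ignore_ids: [n_vocab] array, 1 for tokens that must not be sampled, 0 otherwise
    """
    # Subtract a huge constant from the logits of ignored tokens (broadcast over the batch).
    masked = logits - penalty * ignore_ids[None, :].astype(logits.dtype)
    # Numerically stable softmax: shift by the row max before exponentiating.
    masked = masked - masked.max(axis=-1, keepdims=True)
    exp = np.exp(masked)
    return exp / exp.sum(axis=-1, keepdims=True)

logits = np.array([[2.0, 1.0, 0.5, 3.0]])
ignore_ids = np.array([0, 0, 0, 1])  # forbid token 3 even though it has the highest logit
probs = masked_softmax(logits, ignore_ids)
print(probs.round(3))
```

Token 3 would dominate an unmasked softmax; after masking, its probability is effectively zero and the remaining mass is shared among tokens 0 to 2.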

(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False)
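`p_for_topp` serves two purposes: it is the probability mass `p` for top-p sampling, and when `do_topk=True` it is converted with `tf.cast(p_for_topp, dtype=tf.int32)` and used as `k`. A float-to-int cast truncates toward zero, so the default of 0.95 becomes `k = 0`, which is almost certainly not a useful value. The plain-Python sketch below mimics that cast; `effective_k` is a hypothetical helper, not part of the repository.

```python
def effective_k(p_for_topp):
    """Mimic tf.cast(p_for_topp, tf.int32): float-to-int casts truncate toward zero."""
    return int(p_for_topp)

print(effective_k(0.95))  # default p_for_topp -> k = 0
print(effective_k(40))    # an integer-valued threshold -> k = 40
```

In practice, a caller enabling top-k should pass an integer-valued threshold (for example 40) rather than relying on the default.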


```python
import tensorflow as tf

# GroverModel, get_shape_list, _top_k_sample and _top_p_sample are defined
# elsewhere in modeling.py.


def sample_step(tokens, ignore_ids, news_config, batch_size=1, p_for_topp=0.95, cache=None, do_topk=False):
    """
    Helper function that samples from Grover for a single step
    :param tokens: [batch_size, n_ctx_b] tokens that we will predict from
    :param ignore_ids: [n_vocab] mask of the tokens we don't want to predict
    :param news_config: config for the GroverModel
    :param batch_size: batch size to use
    :param p_for_topp: top-p threshold, or the top-k threshold (cast to int) if do_topk is True
    :param cache: [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_a,
                   news_config.hidden_size // news_config.num_attention_heads] or None
    :param do_topk: use top-k sampling instead of top-p (nucleus) sampling
    :return: dict with
             new_tokens, size [batch_size]
             new_probs, also size [batch_size]
             new_cache, size [batch_size, news_config.num_hidden_layers, 2,
                   news_config.num_attention_heads, n_ctx_b,
                   news_config.hidden_size // news_config.num_attention_heads]
    """
    model = GroverModel(
        config=news_config,
        is_training=False,
        input_ids=tokens,
        reuse=tf.compat.v1.AUTO_REUSE,
        scope='newslm',
        chop_off_last_token=False,
        do_cache=True,
        cache=cache,
    )
    # Keep only the logits for the final position of each sequence: [batch_size, vocab_size]
    batch_size_times_seq_length, vocab_size = get_shape_list(model.logits_flat, expected_rank=2)
    next_logits = tf.reshape(model.logits_flat, [batch_size, -1, vocab_size])[:, -1]

    # Note: ignore_ids is only applied on the top-p path.
    if do_topk:
        sample_info = _top_k_sample(next_logits, num_samples=1, k=tf.cast(p_for_topp, dtype=tf.int32))
    else:
        sample_info = _top_p_sample(next_logits, ignore_ids=ignore_ids, num_samples=1, p=p_for_topp)

    # sample_info['sample'] is [batch_size, 1]; drop the num_samples axis.
    new_tokens = tf.squeeze(sample_info['sample'], 1)
    # Probability of each sampled token (tf.batch_gather is not available in TF 2.x).
    new_probs = tf.squeeze(tf.gather(sample_info['probs'], sample_info['sample'], batch_dims=1), 1)

    return {
        'new_tokens': new_tokens,
        'new_probs': new_probs,
        'new_cache': model.new_kvs,
    }
```
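The two shape manipulations at the end of the function, keeping only the last position's logits and reading off the probability of each sampled token, can be reproduced in NumPy to check the shapes. This assumes `model.logits_flat` is laid out batch-major as `[batch_size * seq_length, vocab_size]`, which is what the reshape to `[batch_size, -1, vocab_size]` implies; the random logits and the fixed `sample` values are made up for illustration.

```python
import numpy as np

batch_size, seq_len, vocab_size = 2, 3, 5
rng = np.random.default_rng(0)

# Stand-in for model.logits_flat: one row of logits per (batch, position) pair.
logits_flat = rng.normal(size=(batch_size * seq_len, vocab_size))

# Same reshape as the TF code; -1 lets the sequence length be inferred.
next_logits = logits_flat.reshape(batch_size, -1, vocab_size)[:, -1]  # [batch_size, vocab_size]

# Stand-ins for sample_info['probs'] and sample_info['sample'] ([batch_size, 1]).
probs = np.exp(next_logits) / np.exp(next_logits).sum(axis=-1, keepdims=True)
sample = np.array([[4], [1]])

# Equivalent of tf.gather(probs, sample, batch_dims=1) followed by squeezing axis 1.
new_probs = np.take_along_axis(probs, sample, axis=1).squeeze(1)  # [batch_size]
new_tokens = sample.squeeze(1)                                     # [batch_size]
print(next_logits.shape, new_tokens, new_probs)
```

Row `b` of `next_logits` is row `b * seq_len + seq_len - 1` of `logits_flat`, that is, the last position of sequence `b`.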

Callers (2)

initialize_from_context (function) · 0.70
body (function) · 0.70
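This page does not show how the two callers use the returned cache. Under the layout documented for the `cache` argument, a caller would prime the cache from the context and then append each step's `new_cache` along the sequence axis (axis -2), so later steps only need to feed the newest token to GroverModel. The NumPy sketch below only tracks shapes; `fake_new_cache` and the sizes are hypothetical, and it assumes the returned cache uses the same axis order as the `cache` argument.

```python
import numpy as np

# Hypothetical sizes; in Grover these would come from news_config.
batch_size, num_layers, num_heads, head_dim = 1, 2, 4, 16

def fake_new_cache(n_new):
    """Stand-in for sample_step()['new_cache'] covering n_new freshly processed tokens."""
    return np.zeros((batch_size, num_layers, 2, num_heads, n_new, head_dim))

# Prime the cache with a 5-token context (the role assumed for initialize_from_context).
cache = fake_new_cache(5)

# Each later step feeds only the newest token, so its new_cache covers one position,
# and it is appended along the sequence axis (axis -2 in this layout).
for _ in range(3):
    cache = np.concatenate([cache, fake_new_cache(1)], axis=-2)

print(cache.shape)  # (1, 2, 2, 4, 8, 16)
```

The sequence axis grows from `n_ctx_a = 5` to 8 after three single-token steps while every other dimension stays fixed.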

Calls (4)

get_shape_list (function) · 0.90
GroverModel (class) · 0.70
_top_k_sample (function) · 0.70
_top_p_sample (function) · 0.70
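`_top_k_sample` and `_top_p_sample` are not shown on this page. As a rough guide to what the two branches choose between, here is a NumPy sketch of standard top-k filtering and top-p (nucleus) filtering on a `[batch_size, n_vocab]` probability matrix, plus a sampler that returns `[batch_size, 1]` ids like `sample_info['sample']`. These are textbook versions rather than the repository's helpers; the function names are hypothetical, and the repository may differ in details such as whether the token that crosses the threshold `p` is kept (this sketch keeps it).

```python
import numpy as np

def top_k_filter(probs, k):
    """Zero out all but the k most probable tokens in each row, then renormalize."""
    # Indices of the k largest probabilities per row (order within the top k is irrelevant).
    top = np.argpartition(-probs, k - 1, axis=-1)[:, :k]
    keep = np.zeros_like(probs, dtype=bool)
    np.put_along_axis(keep, top, True, axis=-1)
    filtered = np.where(keep, probs, 0.0)
    return filtered / filtered.sum(axis=-1, keepdims=True)

def top_p_filter(probs, p):
    """Keep the smallest set of most probable tokens whose total mass reaches p."""
    order = np.argsort(-probs, axis=-1)  # most probable first
    sorted_probs = np.take_along_axis(probs, order, axis=-1)
    cumulative = np.cumsum(sorted_probs, axis=-1)
    # Keep a token if the mass ranked strictly above it is still below p;
    # this always keeps at least the single most probable token.
    keep_sorted = (cumulative - sorted_probs) < p
    keep = np.zeros_like(keep_sorted)
    np.put_along_axis(keep, order, keep_sorted, axis=-1)
    filtered = np.where(keep, probs, 0.0)
    return filtered / filtered.sum(axis=-1, keepdims=True)

def sample(probs, rng):
    """Draw one token id per row, shaped [batch_size, 1]."""
    return np.array([[rng.choice(probs.shape[-1], p=row)] for row in probs])

probs = np.array([[0.5, 0.3, 0.15, 0.05]])
rng = np.random.default_rng(0)
print(top_k_filter(probs, 2))    # only tokens 0 and 1 survive
print(top_p_filter(probs, 0.9))  # tokens 0, 1 and 2 survive (0.5 + 0.3 < 0.9)
print(sample(top_p_filter(probs, 0.9), rng))
```

Top-k keeps a fixed number of candidates regardless of how peaked the distribution is, while top-p adapts the candidate set to the distribution's shape, which is why the two use differently typed thresholds.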

Tested by

no test coverage detected