MCPcopy Index your code
hub / github.com/Turing-Project/WriteGPT / sample

Function sample

LanguageNetwork/GPT2/train/modeling.py:744–808  ·  view source on GitHub ↗

V1 version of: sample outputs from a model, and do it all at once. :param news_config: Configuration used to construct the model :param initial_context: [batch_size, seq_length] that we'll start generating with :param eos_token: Stop generating if you see this (tf scalar) :param min_len: min length of sample :param ignore_ids: NEVER GENERATE THESE [vocab_size]

(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95,
           do_topk=False)

Source from the content-addressed store, hash-verified

742
743
744def sample(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95,
745 do_topk=False):
746 """
747 V1 version of: sample outputs from a model, and do it all at once
748 :param news_config: Configuration used to construct the model
749 :param initial_context: [batch_size, seq_length] that we'll start generating with
750 :param eos_token: Stop generating if you see this (tf scalar)
751 :param min_len: min length of sample
752 :param ignore_ids: NEVER GENERATE THESE [vocab_size]
753 :return:
754 """
755 batch_size, _ = get_shape_list(initial_context, expected_rank=2)
756
757 if ignore_ids is None:
758 ignore_ids = tf.constant([x == 0 for x in range(news_config.vocab_size)], dtype=tf.bool)
759
760 with tf.name_scope('sample_sequence'):
761 # Initial call to get cache
762 context_output = initialize_from_context(initial_context, ignore_ids=ignore_ids, news_config=news_config,
763 p_for_topp=p_for_topp,
764 do_topk=do_topk)
765 ctx = context_output['tokens']
766 cache = context_output['cache']
767 probs = context_output['probs']
768
769 def body(ctx, cache, probs):
770 """ for whatever reason this didn't work when I ran it on more than one at once... ugh."""
771 next_outputs = sample_step(ctx[:, -1][:, None], ignore_ids=ignore_ids, news_config=news_config,
772 batch_size=batch_size, p_for_topp=p_for_topp, cache=cache,
773 do_topk=do_topk)
774
775
776 # Update everything
777 new_cache = tf.concat([cache, next_outputs['new_cache']], axis=-2)
778 new_ids = tf.concat([ctx, next_outputs['new_tokens'][:, None]], axis=1)
779 new_probs = tf.concat([probs, next_outputs['new_probs'][:, None]], axis=1)
780 return [new_ids, new_cache, new_probs]
781
782 def cond(ctx, cache, probs):
783 # ctx = tf.Print(ctx,[tf.shape(ctx)])
784 # print('kkkkkkkkkkkkk')
785 # print(ctx[:,-1:])
786 is_eos = tf.reduce_all(tf.reduce_any(tf.equal(ctx[:,-1:], eos_token), axis=1))
787 # print('-----------------')
788 # print(is_eos)
789 is_len = tf.greater(get_shape_list(ctx)[1], min_len)
790 return tf.logical_not(tf.logical_and(is_eos, is_len))
791
792 tokens, cache, probs = tf.while_loop(
793 cond=cond, body=body, maximum_iterations=1025 - get_shape_list(ctx)[1],
794 loop_vars=[ctx, cache, probs],
795 shape_invariants=[tf.TensorShape([batch_size, None]),
796 tf.TensorShape(
797 [batch_size, news_config.num_hidden_layers, 2,
798 news_config.num_attention_heads,
799 None, news_config.hidden_size // news_config.num_attention_heads]),
800 tf.TensorShape([batch_size, None]),
801 ],

Callers

nothing calls this directly

Calls 2

get_shape_list — Function · 0.90
initialize_from_context — Function · 0.70

Tested by

no test coverage detected