MCPcopy
hub / github.com/Turing-Project/WriteGPT / sample

Function sample

LanguageNetwork/GPT2/scripts/modeling.py:748–812  ·  view source on GitHub ↗

V1 version of: sample outputs from a model, and do it all at once. :param news_config: Configuration used to construct the model :param initial_context: [batch_size, seq_length] that we'll start generating with :param eos_token: Stop generating if you see this (tf scalar) :param min_len: min length of sample :param ignore_ids: NEVER GENERATE THESE [vocab_size]

(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95,
           do_topk=False)

Source from the content-addressed store, hash-verified

746
747
748def sample(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95,
749 do_topk=False):
750 """
751 V1 version of: sample outputs from a model, and do it all at once
752 :param news_config: Configuration used to construct the model
753 :param initial_context: [batch_size, seq_length] that we'll start generating with
754 :param eos_token: Stop generating if you see this (tf scalar)
755 :param min_len: min length of sample
756 :param ignore_ids: NEVER GENERATE THESE [vocab_size]
757 :return:
758 """
759 batch_size, _ = get_shape_list(initial_context, expected_rank=2)
760
761 if ignore_ids is None:
762 ignore_ids = tf.constant([x == 0 for x in range(news_config.vocab_size)], dtype=tf.bool)
763
764 with tf.name_scope('sample_sequence'):
765 # Initial call to get cache
766 context_output = initialize_from_context(initial_context, ignore_ids=ignore_ids, news_config=news_config,
767 p_for_topp=p_for_topp,
768 do_topk=do_topk)
769 ctx = context_output['tokens']
770 cache = context_output['cache']
771 probs = context_output['probs']
772
773 def body(ctx, cache, probs):
774 """ for whatever reason this didn't work when I ran it on more than one at once... ugh."""
775 next_outputs = sample_step(ctx[:, -1][:, None], ignore_ids=ignore_ids, news_config=news_config,
776 batch_size=batch_size, p_for_topp=p_for_topp, cache=cache,
777 do_topk=do_topk)
778
779
780 # Update everything
781 new_cache = tf.concat([cache, next_outputs['new_cache']], axis=-2)
782 new_ids = tf.concat([ctx, next_outputs['new_tokens'][:, None]], axis=1)
783 new_probs = tf.concat([probs, next_outputs['new_probs'][:, None]], axis=1)
784 return [new_ids, new_cache, new_probs]
785
786 def cond(ctx, cache, probs):
787 # ctx = tf.Print(ctx,[tf.shape(ctx)])
788 # print('kkkkkkkkkkkkk')
789 # print(ctx[:,-1:])
790 is_eos = tf.reduce_all(tf.reduce_any(tf.equal(ctx[:,-1:], eos_token), axis=1))
791 # print('-----------------')
792 # print(is_eos)
793 is_len = tf.greater(get_shape_list(ctx)[1], min_len)
794 return tf.logical_not(tf.logical_and(is_eos, is_len))
795
796 tokens, cache, probs = tf.while_loop(
797 cond=cond, body=body, maximum_iterations=1025 - get_shape_list(ctx)[1],
798 loop_vars=[ctx, cache, probs],
799 shape_invariants=[tf.TensorShape([batch_size, None]),
800 tf.TensorShape(
801 [batch_size, news_config.num_hidden_layers, 2,
802 news_config.num_attention_heads,
803 None, news_config.hidden_size // news_config.num_attention_heads]),
804 tf.TensorShape([batch_size, None]),
805 ],

Callers 2

demo.py (File) · 0.90

Calls 2

get_shape_list (Function) · 0.90
initialize_from_context (Function) · 0.70

Tested by

no test coverage detected