V1 version of: sample outputs from a model, and do it all at once. :param news_config: Configuration used to construct the model. :param initial_context: [batch_size, seq_length] tokens that we'll start generating with. :param eos_token: Stop generating if you see this (tf scalar). :param …
(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95,
do_topk=False)
| 746 | |
| 747 | |
| 748 | def sample(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95, |
| 749 | do_topk=False): |
| 750 | """ |
| 751 | V1 version of: sample outputs from a model, and do it all at once |
| 752 | :param news_config: Configuration used to construct the model |
| 753 | :param initial_context: [batch_size, seq_length] that we'll start generating with |
| 754 | :param eos_token: Stop generating if you see this (tf scalar) |
| 755 | :param min_len: min length of sample |
| 756 | :param ignore_ids: NEVER GENERATE THESE [vocab_size] |
| 757 | :return: |
| 758 | """ |
| 759 | batch_size, _ = get_shape_list(initial_context, expected_rank=2) |
| 760 | |
| 761 | if ignore_ids is None: |
| 762 | ignore_ids = tf.constant([x == 0 for x in range(news_config.vocab_size)], dtype=tf.bool) |
| 763 | |
| 764 | with tf.name_scope('sample_sequence'): |
| 765 | # Initial call to get cache |
| 766 | context_output = initialize_from_context(initial_context, ignore_ids=ignore_ids, news_config=news_config, |
| 767 | p_for_topp=p_for_topp, |
| 768 | do_topk=do_topk) |
| 769 | ctx = context_output['tokens'] |
| 770 | cache = context_output['cache'] |
| 771 | probs = context_output['probs'] |
| 772 | |
| 773 | def body(ctx, cache, probs): |
| 774 | """One while_loop step: call sample_step on the last token of ctx and append the returned token, cache and probs to the loop state. Author's note kept from the original: 'for whatever reason this didn't work when I ran it on more than one at once... ugh.' (NOTE(review): unclear what failed; verify).""" |
| 775 | next_outputs = sample_step(ctx[:, -1][:, None], ignore_ids=ignore_ids, news_config=news_config, |
| 776 | batch_size=batch_size, p_for_topp=p_for_topp, cache=cache, |
| 777 | do_topk=do_topk) |
| 778 | # Only the newest token (ctx[:, -1], reshaped to [batch, 1]) is passed in; earlier positions are presumably supplied via cache (verify). |
| 779 | |
| 780 | # Append this step's outputs: cache grows along axis -2 (the None dim in the while_loop shape_invariants); ids and probs grow along axis 1. |
| 781 | new_cache = tf.concat([cache, next_outputs['new_cache']], axis=-2) |
| 782 | new_ids = tf.concat([ctx, next_outputs['new_tokens'][:, None]], axis=1) |
| 783 | new_probs = tf.concat([probs, next_outputs['new_probs'][:, None]], axis=1) |
| 784 | return [new_ids, new_cache, new_probs]  # same order as loop_vars=[ctx, cache, probs] |
| 785 | |
| 786 | def cond(ctx, cache, probs): |
| 787 | # Loop condition for tf.while_loop: keep generating while this returns True, i.e. until both stop criteria below hold. |
| 788 | # is_eos: True only when EVERY row's most recent token (ctx[:, -1:]) equals eos_token at the same time. |
| 789 | # NOTE(review): rows that emitted eos_token earlier keep being extended, and a row whose last token is no longer eos_token blocks the stop; confirm this is intended for batch_size > 1. |
| 790 | is_eos = tf.reduce_all(tf.reduce_any(tf.equal(ctx[:,-1:], eos_token), axis=1)) |
| 791 | # is_len: True once the whole ctx length along axis 1 (initial tokens plus those appended by body) exceeds min_len. |
| 792 | # Stop (return False) only when both is_eos and is_len are True; maximum_iterations on the while_loop is the other exit. |
| 793 | is_len = tf.greater(get_shape_list(ctx)[1], min_len) |
| 794 | return tf.logical_not(tf.logical_and(is_eos, is_len)) |
| 795 | |
| 796 | tokens, cache, probs = tf.while_loop( |
| 797 | cond=cond, body=body, maximum_iterations=1025 - get_shape_list(ctx)[1], |
| 798 | loop_vars=[ctx, cache, probs], |
| 799 | shape_invariants=[tf.TensorShape([batch_size, None]), |
| 800 | tf.TensorShape( |
| 801 | [batch_size, news_config.num_hidden_layers, 2, |
| 802 | news_config.num_attention_heads, |
| 803 | None, news_config.hidden_size // news_config.num_attention_heads]), |
| 804 | tf.TensorShape([batch_size, None]), |
| 805 | ], |
no test coverage detected