For an unknown reason, this did not work when run on more than one at once.
(ctx, cache, probs)
| 771 | probs = context_output['probs'] |
| 772 | |
def body(ctx, cache, probs):
    """Advance generation by one sampling step and append its outputs.

    Presumably used as the loop body of a ``tf.while_loop`` paired with
    ``cond`` (same ``(ctx, cache, probs)`` signature) -- confirm at the
    call site.

    NOTE(review): the original author reported this not working "when I ran
    it on more than one at once"; the cause was never recorded.

    Args:
        ctx: token ids so far; only the last entry along axis 1 is fed to
            ``sample_step``.
        cache: cached state passed to ``sample_step``; the step's
            ``new_cache`` is concatenated onto it along axis -2.
        probs: accumulated ``new_probs`` outputs; the step's value is
            appended along axis 1.

    Returns:
        list: ``[new_ids, new_cache, new_probs]``, ordered like the inputs
        ``(ctx, cache, probs)``.
    """
    # Take the last entry on axis 1, then reinsert a size-1 axis at position 1.
    latest = ctx[:, -1][:, None]
    step = sample_step(
        latest,
        ignore_ids=ignore_ids,
        news_config=news_config,
        batch_size=batch_size,
        p_for_topp=p_for_topp,
        cache=cache,
        do_topk=do_topk,
    )

    # Grow each piece of loop state with this step's output.
    grown_cache = tf.concat([cache, step['new_cache']], axis=-2)
    grown_ids = tf.concat([ctx, step['new_tokens'][:, None]], axis=1)
    grown_probs = tf.concat([probs, step['new_probs'][:, None]], axis=1)
    return [grown_ids, grown_cache, grown_probs]
| 785 | |
| 786 | def cond(ctx, cache, probs): |
| 787 | # ctx = tf.Print(ctx,[tf.shape(ctx)]) |
nothing calls this directly
no test coverage detected