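Shares `sample_step`'s options (`ignore_ids`, `news_config`, `p_for_topp`, `do_topk`); `tokens`, `batch_size` and `cache` are derived from `initial_context` rather than passed in.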
(initial_context, ignore_ids, news_config, p_for_topp=0.95, do_topk=False)
| 733 | |
| 734 | |
def initialize_from_context(initial_context, ignore_ids, news_config, p_for_topp=0.95, do_topk=False):
    """Run one sampling step over a full context and seed the decoding state.

    Accepts the same sampling options as ``sample_step`` (``ignore_ids``,
    ``news_config``, ``p_for_topp``, ``do_topk``). The remaining
    ``sample_step`` arguments are filled in here:

    * ``tokens`` is the whole ``initial_context``;
    * ``batch_size`` is the first dimension of ``initial_context``;
    * ``cache`` is ``None``, so no previous cache is passed in.

    :param initial_context: token tensor, required by ``get_shape_list`` to be
        rank 2; dim 0 is the batch (assumed (batch, length) -- TODO confirm).
    :param ignore_ids: forwarded unchanged to ``sample_step``.
    :param news_config: forwarded unchanged to ``sample_step``.
    :param p_for_topp: forwarded unchanged to ``sample_step``.
    :param do_topk: forwarded unchanged to ``sample_step``.
    :return: dict with
        ``'tokens'``: ``initial_context`` with the sampled token appended on axis 1;
        ``'cache'``: the ``'new_cache'`` produced by ``sample_step``;
        ``'probs'``: ``sample_step``'s ``'new_probs'`` with a new axis at position 1.
    """
    # Unpack both dims so a non-rank-2 shape list still fails loudly here.
    n_batch, _context_len = get_shape_list(initial_context, expected_rank=2)

    first_step = sample_step(
        tokens=initial_context,
        ignore_ids=ignore_ids,
        news_config=news_config,
        batch_size=n_batch,
        p_for_topp=p_for_topp,
        cache=None,
        do_topk=do_topk,
    )

    # [:, None] lifts the per-row sample to a length-1 column so it can be
    # concatenated onto the context along axis 1.
    next_token_col = first_step['new_tokens'][:, None]
    extended_tokens = tf.concat([initial_context, next_token_col], 1)
    next_probs_col = first_step['new_probs'][:, None]

    return {
        'tokens': extended_tokens,
        'cache': first_step['new_cache'],
        'probs': next_probs_col,
    }
| 746 | |
| 747 | |
| 748 | def sample(news_config: GroverConfig, initial_context, eos_token, min_len, ignore_ids=None, p_for_topp=0.95, |
no test coverage detected