| 839 | |
| 840 | |
def get_token_stream(
    model,
    context_tokens,
    return_scores: bool = False,
    prompt_length: int = None,
    micro_batch_size: int = None,
    bad_ids: List = None,
    temperature: float = None,
    topp: float = None,
    topk: int = None,
):
    """Stream generation results for a batch of context token lists.

    The contexts are run through ``pad_batch`` (given the tokenizer's ``eod``
    id and the global args), copied into CUDA ``LongTensor``s, and broadcast
    from the tensor-model-parallel source rank to its group. The attention
    mask and position ids come from ``get_batch``; the actual decoding is
    delegated to ``sample_sequence_batch``.

    Args:
        model: Forwarded unchanged to ``sample_sequence_batch``.
        context_tokens: Batch of contexts handed to ``pad_batch``.
            NOTE(review): presumably a list of token-id lists -- confirm
            against ``pad_batch``.
        return_scores, prompt_length, bad_ids, temperature, topp, topk:
            Forwarded unchanged, as keyword arguments, to
            ``sample_sequence_batch``.
        micro_batch_size: Forwarded as the second argument of ``get_batch``.

    Yields:
        When ``args.beam_search`` is truthy, every item produced by
        ``sample_sequence_batch`` as-is. Otherwise, for each
        ``(tokens, lengths)`` item: ``(tokens[:, :n], lengths)`` where ``n``
        starts at the smallest context length plus one and grows by one per
        item, or ``(None, None)`` when ``tokens`` is ``None``.
    """
    args = get_args()
    padded_contexts, padded_lengths = pad_batch(
        context_tokens, get_tokenizer().eod, args
    )

    context_tokens_tensor = torch.cuda.LongTensor(padded_contexts)
    context_length_tensor = torch.cuda.LongTensor(padded_lengths)

    # Lengths are broadcast before tokens; collectives must be issued in the
    # same order on every rank, so keep this sequence fixed.
    for shared_tensor in (context_length_tensor, context_tokens_tensor):
        torch.distributed.broadcast(
            shared_tensor,
            mpu.get_tensor_model_parallel_src_rank(),
            group=mpu.get_tensor_model_parallel_group(),
        )

    shortest_context = context_length_tensor.min().item()

    # Only the mask and position ids are needed here; the first element
    # returned by get_batch is not used by this function.
    _, attention_mask, position_ids = get_batch(
        context_tokens_tensor, micro_batch_size
    )

    sampler = sample_sequence_batch(
        model,
        context_tokens_tensor,
        context_length_tensor,
        attention_mask,
        position_ids,
        return_scores=return_scores,
        prompt_length=prompt_length,
        bad_ids=bad_ids,
        temperature=temperature,
        topp=topp,
        topk=topk,
    )

    if args.beam_search:
        # Beam-search output is passed through untouched.
        for beam_result in sampler:
            yield beam_result
        return

    # The visible window grows by one position for every item the sampler
    # produces, starting just past the shortest context in the batch.
    for visible_length, (step_tokens, step_lengths) in enumerate(
        sampler, start=shortest_context + 1
    ):
        if step_tokens is None:
            yield None, None
        else:
            yield step_tokens[:, :visible_length], step_lengths
| 898 | |