MCPcopy Index your code
hub / github.com/zai-org/CodeGeeX / get_token_stream

Function get_token_stream

codegeex/megatron/code_generation_utils.py:841–897  ·  view source on GitHub ↗
(
        model,
        context_tokens,
        return_scores: bool = False,
        prompt_length: int = None,
        micro_batch_size: int = None,
        bad_ids: List = None,
        temperature: float = None,
        topp: float = None,
        topk: int = None,
)

Source from the content-addressed store, hash-verified

839
840
def get_token_stream(
    model,
    context_tokens,
    return_scores: bool = False,
    prompt_length: int = None,
    micro_batch_size: int = None,
    bad_ids: List = None,
    temperature: float = None,
    topp: float = None,
    topk: int = None,
):
    """Sample continuations for a batch of prompts, yielding results step by step.

    The prompts are padded via ``pad_batch`` (which is given ``tokenizer.eod``,
    presumably as the pad id), turned into CUDA ``LongTensor``s, and broadcast
    from the tensor-model-parallel source rank so every rank in that group works
    on the same inputs. Sampling is delegated to ``sample_sequence_batch``.

    Args:
        model: Passed unchanged to ``sample_sequence_batch``.
        context_tokens: Batch of prompt token-id sequences; padded here.
        return_scores: Forwarded to ``sample_sequence_batch``.
        prompt_length: Forwarded to ``sample_sequence_batch``.
        micro_batch_size: Forwarded to ``get_batch`` when building the
            attention mask and position ids.
        bad_ids: Forwarded to ``sample_sequence_batch``.
        temperature: Forwarded to ``sample_sequence_batch``.
        topp: Forwarded to ``sample_sequence_batch``.
        topk: Forwarded to ``sample_sequence_batch``.

    Yields:
        When ``args.beam_search`` is set: every item produced by
        ``sample_sequence_batch``, unchanged.
        Otherwise: ``(tokens[:, :n], lengths)`` per step, where ``n`` starts at
        the shortest context length plus one and grows by one each step; or
        ``(None, None)`` when the sampler yields ``tokens is None`` (presumably
        on ranks that do not hold the output -- TODO confirm).
    """
    args = get_args()
    tokenizer = get_tokenizer()

    context_tokens, context_lengths = pad_batch(context_tokens, tokenizer.eod, args)

    context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    context_length_tensor = torch.cuda.LongTensor(context_lengths)

    # Give every rank in the tensor-model-parallel group identical inputs:
    # first the per-sample lengths, then the padded token ids.
    torch.distributed.broadcast(
        context_length_tensor,
        mpu.get_tensor_model_parallel_src_rank(),
        group=mpu.get_tensor_model_parallel_group(),
    )
    torch.distributed.broadcast(
        context_tokens_tensor,
        mpu.get_tensor_model_parallel_src_rank(),
        group=mpu.get_tensor_model_parallel_group(),
    )

    # The yielded token window is anchored at the shortest prompt in the batch.
    context_length = context_length_tensor.min().item()
    # Only the mask and position ids are used; the first value returned by
    # get_batch is discarded rather than bound to a name the loop below reuses.
    _, attention_mask, position_ids = get_batch(context_tokens_tensor, micro_batch_size)

    batch_token_iterator = sample_sequence_batch(
        model,
        context_tokens_tensor,
        context_length_tensor,
        attention_mask,
        position_ids,
        return_scores=return_scores,
        prompt_length=prompt_length,
        bad_ids=bad_ids,
        temperature=temperature,
        topp=topp,
        topk=topk,
    )

    if args.beam_search:
        # Beam search results are passed through as-is.
        for beams in batch_token_iterator:
            yield beams
    else:
        for tokens, lengths in batch_token_iterator:
            # Widen the yielded window by one position per sampling step.
            context_length += 1
            if tokens is not None:
                yield tokens[:, :context_length], lengths
            else:
                yield None, None
898

Calls 5

get_args · Function · 0.90
get_tokenizer · Function · 0.90
pad_batch · Function · 0.70
get_batch · Function · 0.70
sample_sequence_batch · Function · 0.70

Tested by 1

main · Function · 0.72