github.com/zai-org/CodeGeeX

Function sample_sequence_batch

codegeex/megatron/code_generation_utils.py:905–1080
(
        model,
        context_tokens,
        context_lengths,
        attention_mask,
        position_ids,
        maxlen=None,
        type_ids=None,
        return_scores: bool = False,
        prompt_length: int = None,
        bad_ids: List = None,
        temperature: float = None,
        topp: float = None,
        topk: int = None,
)
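The three sampling overrides (`temperature`, `topp`, `topk`) default to `None` and, in the first lines of the body (source lines 922–924), fall back to the global Megatron arguments `args.temperature`, `args.top_p` and `args.top_k`. A minimal sketch of that resolution, assuming a `types.SimpleNamespace` stands in for the object returned by `get_args()`; the helper name `resolve_sampling_params` and the argument values are hypothetical:

```python
from types import SimpleNamespace
from typing import Optional, Tuple


def resolve_sampling_params(
    args: SimpleNamespace,
    temperature: Optional[float] = None,
    topp: Optional[float] = None,
    topk: Optional[int] = None,
) -> Tuple[float, float, int]:
    """Hypothetical helper: a per-call value wins, otherwise use the global args."""
    # Only None triggers the fallback, so an explicit 0 or 0.0 is preserved.
    temperature = temperature if temperature is not None else args.temperature
    topp = topp if topp is not None else args.top_p
    topk = topk if topk is not None else args.top_k
    return temperature, topp, topk


# Stand-in for get_args(); these values are illustrative only.
global_args = SimpleNamespace(temperature=0.8, top_p=0.95, top_k=40)

print(resolve_sampling_params(global_args))                          # (0.8, 0.95, 40)
print(resolve_sampling_params(global_args, temperature=0.2, topk=0))  # (0.2, 0.95, 0)
```

Testing `is not None` instead of writing `topk or args.top_k` matters here: a caller passing `topk=0` keeps 0 rather than silently inheriting the global value.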


```python
# torch, List (typing), get_args, get_tokenizer and beam_search are resolved
# at module scope in code_generation_utils.py; they are not shown in this view.
def sample_sequence_batch(
    model,
    context_tokens,
    context_lengths,
    attention_mask,
    position_ids,
    maxlen=None,
    type_ids=None,
    return_scores: bool = False,
    prompt_length: int = None,
    bad_ids: List = None,
    temperature: float = None,
    topp: float = None,
    topk: int = None,
):
    args = get_args()
    tokenizer = get_tokenizer()
    temperature = temperature if temperature is not None else args.temperature
    topp = topp if topp is not None else args.top_p
    topk = topk if topk is not None else args.top_k

    model.eval()
    with torch.no_grad():
        context_length = context_lengths.min().item()

        # added eos_id to support the function generate_samples_eval that passes
        # eos_id as an argument and needs termination when that id is found.
        if hasattr(args, "eos_id"):
            eos_id = args.eos_id
        else:
            eos_id = tokenizer.eod

        counter = 0
        org_context_length = context_length

        layer_past = None
        batch_size = context_tokens.size(0)
        is_done = torch.zeros([batch_size]).byte().cuda()
        tokens = context_tokens
        if maxlen is None:
            maxlen = args.seq_length - 1
            if maxlen > (org_context_length + args.out_seq_length):
                maxlen = org_context_length + args.out_seq_length

        lengths = torch.ones([batch_size]).long().cuda() * maxlen
        if return_scores:
            scores = torch.zeros([batch_size]).float().cuda()

        if args.beam_search:
            beams = beam_search(model, context_tokens=tokens.cpu().numpy().tolist()[0][:context_length],
                                num_beams=args.num_beams)
            if args.beam_warmup:
                beam = beams[0]
                tokens_ = beam.tokens
                tokens_ = (tokens_ if tokens_[-1] != tokenizer.eod else tokens_[:-1])
                tokens_warmup = []
                for i in range(batch_size):
                    tokens_warmup.append(tokens_.copy())
        # ... remainder of the function (source lines 963–1080) not shown in this view
```
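Two setup steps in the listing above, the generation length budget (`maxlen`) and the beam-warmup batch, can be reproduced in plain Python. This is a sketch, not the project's code: `compute_maxlen` and `replicate_best_beam` are hypothetical names, plain integers stand in for `args.seq_length` and `args.out_seq_length`, a list of token ids stands in for `beam.tokens`, and it assumes (as in the Megatron-LM code this module derives from) that the `out_seq_length` clamp applies only when `maxlen` is not supplied.

```python
from typing import List, Optional


def compute_maxlen(
    maxlen: Optional[int],
    context_length: int,
    seq_length: int,
    out_seq_length: int,
) -> int:
    """Hypothetical helper mirroring the maxlen logic (clamp assumed nested)."""
    if maxlen is None:
        # Stay inside the model's context window ...
        maxlen = seq_length - 1
        # ... but generate at most out_seq_length tokens past the prompt.
        if maxlen > context_length + out_seq_length:
            maxlen = context_length + out_seq_length
    return maxlen


def replicate_best_beam(beam_tokens: List[int], eod: int, batch_size: int) -> List[List[int]]:
    """Hypothetical helper: drop a trailing end-of-document id, copy once per row."""
    tokens = beam_tokens if beam_tokens[-1] != eod else beam_tokens[:-1]
    # .copy() gives every row its own list, so later per-row edits do not alias.
    return [tokens.copy() for _ in range(batch_size)]


print(compute_maxlen(None, context_length=100, seq_length=2048, out_seq_length=128))   # 228
print(compute_maxlen(None, context_length=2000, seq_length=2048, out_seq_length=128))  # 2047
print(replicate_best_beam([5, 9, 2, 0], eod=0, batch_size=2))  # [[5, 9, 2], [5, 9, 2]]
```

In the original, `context_length` is `context_lengths.min()`, so the budget is measured from the shortest prompt in the batch, and every entry of `lengths` starts at that same `maxlen`.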

Callers (1)

get_token_stream (function, 0.70)

Calls (9)

get_args (function, 0.90)
get_tokenizer (function, 0.90)
beam_search (function, 0.85)
size (method, 0.80)
pad_batch (function, 0.70)
get_batch (function, 0.70)
top_k_logits (function, 0.70)
switch (function, 0.70)
eval (method, 0.45)
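The call list includes `top_k_logits`, and the function accepts `temperature`, `topp` and `topk`, but the sampling step itself lies in the part of the function not shown here. As a hedged illustration of the standard technique only (temperature scaling, then top-k, then top-p/nucleus filtering), here is a pure-Python sketch over a list of logits. `filter_logits` and `sample_token` are hypothetical names; they are not the signature or implementation of `top_k_logits` in CodeGeeX.

```python
import math
import random
from typing import List


def filter_logits(logits: List[float], temperature: float = 1.0,
                  top_k: int = 0, top_p: float = 0.0) -> List[float]:
    """Set filtered entries to -inf. top_k=0 or top_p=0.0 disables that filter."""
    # temperature must be > 0; lower values sharpen the distribution.
    scaled = [x / temperature for x in logits]
    # Indices ordered from most to least likely.
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
    keep = set(order)
    if top_k > 0:
        keep = set(order[:top_k])
    if top_p > 0.0:
        # Softmax over the surviving entries, most likely first.
        survivors = [i for i in order if i in keep]
        m = scaled[survivors[0]]
        exps = [math.exp(scaled[i] - m) for i in survivors]
        total = sum(exps)
        cumulative = 0.0
        nucleus = set()
        for i, e in zip(survivors, exps):
            nucleus.add(i)  # the token that crosses top_p is kept
            cumulative += e / total
            if cumulative >= top_p:
                break
        keep = nucleus
    return [x if i in keep else float("-inf") for i, x in enumerate(scaled)]


def sample_token(logits: List[float], rng: random.Random, **kwargs) -> int:
    """Draw one index from the softmax of the filtered logits."""
    filtered = filter_logits(logits, **kwargs)
    m = max(filtered)
    weights = [math.exp(x - m) for x in filtered]  # exp(-inf) == 0.0
    return rng.choices(range(len(filtered)), weights=weights, k=1)[0]


print(filter_logits([2.0, 1.0, 0.0, -1.0], top_k=2))    # [2.0, 1.0, -inf, -inf]
print(filter_logits([2.0, 1.0, 0.0, -1.0], top_p=0.9))  # [2.0, 1.0, 0.0, -inf]
print(sample_token([2.0, 1.0, 0.0, -1.0], random.Random(0), top_k=2) in (0, 1))  # True
```

Filtering by setting logits to `-inf` (rather than deleting entries) keeps indices aligned with the vocabulary, which is why the sampler can return a token id directly.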

Tested by

no test coverage detected