(
model,
context_tokens,
context_lengths,
attention_mask,
position_ids,
maxlen=None,
type_ids=None,
return_scores: bool = False,
prompt_length: int = None,
bad_ids: List = None,
temperature: float = None,
topp: float = None,
topk: int = None,
)
| 903 | |
| 904 | |
| 905 | def sample_sequence_batch( |
| 906 | model, |
| 907 | context_tokens, |
| 908 | context_lengths, |
| 909 | attention_mask, |
| 910 | position_ids, |
| 911 | maxlen=None, |
| 912 | type_ids=None, |
| 913 | return_scores: bool = False, |
| 914 | prompt_length: int = None, |
| 915 | bad_ids: List = None, |
| 916 | temperature: float = None, |
| 917 | topp: float = None, |
| 918 | topk: int = None, |
| 919 | ): |
| 920 | args = get_args() |
| 921 | tokenizer = get_tokenizer() |
| 922 | temperature = temperature if temperature is not None else args.temperature |
| 923 | topp = topp if topp is not None else args.top_p |
| 924 | topk = topk if topk is not None else args.top_k |
| 925 | |
| 926 | model.eval() |
| 927 | with torch.no_grad(): |
| 928 | context_length = context_lengths.min().item() |
| 929 | |
| 930 | # added eos_id to support the function generate_samples_eval that passes |
| 931 | # eos_id as an argument and needs termination when that id is found. |
| 932 | if hasattr(args, "eos_id"): |
| 933 | eos_id = args.eos_id |
| 934 | else: |
| 935 | eos_id = tokenizer.eod |
| 936 | |
| 937 | counter = 0 |
| 938 | org_context_length = context_length |
| 939 | |
| 940 | layer_past = None |
| 941 | batch_size = context_tokens.size(0) |
| 942 | is_done = torch.zeros([batch_size]).byte().cuda() |
| 943 | tokens = context_tokens |
| 944 | if maxlen is None: |
| 945 | maxlen = args.seq_length - 1 |
| 946 | if maxlen > (org_context_length + args.out_seq_length): |
| 947 | maxlen = org_context_length + args.out_seq_length |
| 948 | |
| 949 | lengths = torch.ones([batch_size]).long().cuda() * maxlen |
| 950 | if return_scores: |
| 951 | scores = torch.zeros([batch_size]).float().cuda() |
| 952 | |
| 953 | if args.beam_search: |
| 954 | beams = beam_search(model, context_tokens=tokens.cpu().numpy().tolist()[0][:context_length], |
| 955 | num_beams=args.num_beams) |
| 956 | if args.beam_warmup: |
| 957 | beam = beams[0] |
| 958 | tokens_ = beam.tokens |
| 959 | tokens_ = (tokens_ if tokens_[-1] != tokenizer.eod else tokens_[:-1]) |
| 960 | tokens_warmup = [] |
| 961 | for i in range(batch_size): |
| 962 | tokens_warmup.append(tokens_.copy()) |
no test coverage detected