(
model,
tokenizer,
context_tokens,
context_lengths,
attention_mask,
position_ids,
seq_length,
out_seq_length,
maxlen=None,
return_scores: bool = False,
prompt_length: int = None,
bad_ids: List = None,
temperature: float = 1.0,
topp: float = 1.0,
topk: int = 0.0,
recompute: bool = False,
greedy: bool = False,
)
| 215 | |
| 216 | |
| 217 | def sample_sequence_batch( |
| 218 | model, |
| 219 | tokenizer, |
| 220 | context_tokens, |
| 221 | context_lengths, |
| 222 | attention_mask, |
| 223 | position_ids, |
| 224 | seq_length, |
| 225 | out_seq_length, |
| 226 | maxlen=None, |
| 227 | return_scores: bool = False, |
| 228 | prompt_length: int = None, |
| 229 | bad_ids: List = None, |
| 230 | temperature: float = 1.0, |
| 231 | topp: float = 1.0, |
| 232 | topk: int = 0.0, |
| 233 | recompute: bool = False, |
| 234 | greedy: bool = False, |
| 235 | ): |
| 236 | model.eval() |
| 237 | with torch.no_grad(): |
| 238 | context_length = context_lengths.min().item() |
| 239 | eos_id = tokenizer.eos_token_id |
| 240 | |
| 241 | counter = 0 |
| 242 | org_context_length = context_length |
| 243 | |
| 244 | layer_past = None |
| 245 | batch_size = context_tokens.size(0) |
| 246 | is_done = torch.zeros([batch_size]).byte().cuda() |
| 247 | tokens = context_tokens |
| 248 | if maxlen is None: |
| 249 | maxlen = seq_length - 1 |
| 250 | if maxlen > (org_context_length + out_seq_length): |
| 251 | maxlen = org_context_length + out_seq_length |
| 252 | |
| 253 | lengths = torch.ones([batch_size]).long().cuda() * maxlen |
| 254 | if return_scores: |
| 255 | scores = torch.zeros([batch_size]).float().cuda() |
| 256 | |
| 257 | while context_length <= (maxlen): |
| 258 | |
| 259 | if recompute: |
| 260 | logits = model(tokens, |
| 261 | position_ids, |
| 262 | attention_mask, |
| 263 | prompt_length=prompt_length, |
| 264 | context_length=context_length, |
| 265 | ) |
| 266 | logits = logits[:, context_length - 1, :] |
| 267 | else: |
| 268 | if counter == 0: |
| 269 | tokens2use = tokens[:, :context_length] |
| 270 | positions2use = position_ids[:, :context_length] |
| 271 | else: |
| 272 | tokens2use = tokens[:, context_length - 1].view( |
| 273 | batch_size, -1) |
| 274 | positions2use = position_ids[:, context_length - 1].view( |
no test coverage detected