| 136 | return sentence_ids.reshape([-1,]).tolist()[tokenized['input_ids'].shape[1]:] |
| 137 | |
def sample_top_k(scores, top_k):
    """Mask every score below the k-th largest one along the last axis.

    Entries strictly smaller than the ``top_k``-th largest value on the last
    axis are replaced with ``-inf`` (so a following softmax gives them zero
    probability). Entries tied with that k-th value are kept, so more than
    ``top_k`` entries may survive.

    Args:
        scores: jittor Var whose last axis is filtered (presumably per-token
            vocabulary logits -- TODO confirm against callers).
        top_k: how many highest-scoring entries to keep. Values larger than
            the last-axis size are clamped to it. ``top_k <= 0`` disables
            filtering and ``scores`` is returned as-is.

    Returns:
        A Var with the same shape as ``scores`` where filtered entries are
        ``-inf``.
    """
    # Conventionally top_k <= 0 means "no top-k filtering". Without this guard
    # jt.topk would be asked for k <= 0 and the [..., -1] index below fails.
    if top_k <= 0:
        return scores
    top_k = min(top_k, scores.size(-1))  # Safety check: cannot keep more than exist
    # k-th largest value per row; the trailing `None` keeps a length-1 axis so
    # the comparison broadcasts against the full last axis of `scores`.
    kth_largest = jt.topk(scores, top_k)[0][..., -1, None]
    # Remove all tokens with a score less than the last token of the top-k.
    indices_to_remove = scores < kth_largest
    return scores.masked_fill(indices_to_remove, -float("Inf"))
| 145 | |
| 146 | def sample_top_p(scores, top_p): |
| 147 | sorted_logits, sorted_indices = jt.sort(scores, descending=False) |