x: [bs, S, V]; indices: [bs, nb, bl]; output: [bs, nb, bl, V]
(x: Tensor, indices: Tensor, num_beams: int,
beam_length: int)
| 28 | |
| 29 | |
def _unpack_beams(x: Tensor, indices: Tensor, num_beams: int,
                  beam_length: int) -> Tensor:
    """
    Gather entries of ``x`` selected by ``indices`` and lay them out per beam.

    x: [bs, S, V]
    indices: [bs, nb, bl]
    output: [bs, nb, bl, V]

    NOTE(review): presumably every index picks a position along the S axis
    of its own batch entry (this relies on gather_nd's default batch
    handling) -- confirm against gather_nd.
    """
    assert x.rank() == 3
    batch_dim = shape(x, 0, INT_DTYPE_STR)
    last_dim = shape(x, -1, INT_DTYPE_STR)
    # Merge the beam and position axes into one and append a trailing axis of
    # size 1, so each entry is a single-coordinate index for gather_nd.
    flat_indices = view(indices, [-1, num_beams * beam_length, 1], False)
    # Target shape that splits the merged axis back into [nb, bl].
    unpacked_shape = concat([batch_dim, num_beams, beam_length, last_dim])
    gathered = gather_nd(x, flat_indices)
    return view(gathered, unpacked_shape, False)  # [bs, nb, bl, V]
| 44 | |
| 45 | |
| 46 | def _validate_draft_tokens(draft_log_probs: Tensor, |