hub / github.com/NVIDIA/TensorRT-LLM / _unpack_beams

Function _unpack_beams

tensorrt_llm/models/redrafter/redrafter_helper.py:30–43

x: [bs, S, V]; indices: [bs, nb, bl]; output: [bs, nb, bl, V]

(x: Tensor, indices: Tensor, num_beams: int, beam_length: int) -> Tensor

Source from the content-addressed store, hash-verified

def _unpack_beams(x: Tensor, indices: Tensor, num_beams: int,
                  beam_length: int) -> Tensor:
    """
    x: [bs, S, V]
    indices: [bs, nb, bl]
    output:
    """
    assert x.rank() == 3
    d0 = shape(x, 0, INT_DTYPE_STR)
    dl = shape(x, -1, INT_DTYPE_STR)
    indices = view(indices, [-1, num_beams * beam_length, 1], False)
    res_shape = concat([d0, num_beams, beam_length, dl])
    res = view(gather_nd(x, indices), res_shape, False)  # [d0, nb, bl, dl]
    return res
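
The indexing performed by _unpack_beams can be illustrated outside TensorRT-LLM with a small pure-Python sketch on nested lists. It is not the library's implementation: the helper name unpack_beams_reference is hypothetical, and the sketch assumes that gather_nd treats the leading dimension as a batch dimension, so that each entry of indices selects one whole row of x along S within the same batch element. That reading is consistent with the [d0, nb, bl, dl] shape comment in the source.

```python
def unpack_beams_reference(x, indices, num_beams, beam_length):
    """Hypothetical nested-list stand-in for _unpack_beams.

    x:       [bs][S][V]
    indices: [bs][nb][bl], each entry a position along S
    returns: [bs][nb][bl][V]
    """
    out = []
    for x_b, idx_b in zip(x, indices):
        # Mirrors view(indices, [-1, nb * bl, 1]): one flat index list per batch row.
        flat = [s for beam in idx_b for s in beam]
        assert len(flat) == num_beams * beam_length
        # Assumed gather_nd behaviour: pick whole rows of x_b, one per flat index.
        picked = [list(x_b[s]) for s in flat]
        # Mirrors the final view back to [nb, bl, V] for this batch row.
        out.append([picked[i * beam_length:(i + 1) * beam_length]
                    for i in range(num_beams)])
    return out


# bs=1, S=3, V=2; two beams of length two that both start at position 0.
x = [[[10, 11], [20, 21], [30, 31]]]
indices = [[[0, 1], [0, 2]]]
result = unpack_beams_reference(x, indices, num_beams=2, beam_length=2)
print(result)  # [[[[10, 11], [20, 21]], [[10, 11], [30, 31]]]]
```

Because beams may share positions, the same row of x can appear in several beams of the output, as position 0 does here.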

Callers 1

_validate_draft_tokens · Function · 0.85

Calls 5

shape · Function · 0.90
view · Function · 0.90
concat · Function · 0.90
gather_nd · Function · 0.90
rank · Method · 0.45

Tested by

no test coverage detected