MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / get_beam_width_array

Function get_beam_width_array

examples/utils.py:258–289  ·  view source on GitHub ↗
(bwa: str = None)

Source from the content-addressed store, hash-verified

256
257
def get_beam_width_array(bwa: "str | list | tuple | None" = None):
    """Parse a beam-width-array (BWA) specification into padded rows.

    ``bwa`` is normally the raw command-line string. It is parsed with
    ``ast.literal_eval``. If that yields another string (the value was
    quoted twice), the string is parsed once more. An already-parsed list
    or tuple is also accepted and used as-is.

    Two shapes are supported:

    * a list of ints, e.g. ``"[1,2,4]"``: a single BWA;
    * a tuple of such lists, e.g. ``"([1,2],[3,4,5])"``: one BWA per
      request.

    Each row must hold between 1 and 8 integers. It is right-padded to
    exactly 8 entries by repeating its last value.

    Args:
        bwa: The BWA specification described above.

    Returns:
        A pair ``(array, max_beam_width)``:

        * ``array`` is a list of 8 ints for list input, or a list of such
          lists for tuple input.
        * ``max_beam_width`` is the largest value found in any row.

    Raises:
        ValueError: If ``bwa`` is missing or is not a valid Python literal.
            Also raised if it is neither a list nor a tuple, if it is an
            empty tuple, or if any row is not a non-empty list of at most
            8 integers.
    """
    max_len = 8  # Current per-row length limit enforced by this helper.

    if bwa is None:
        raise ValueError("Beam width array must be provided.")

    def parse_literal(value):
        # Normalize every literal_eval failure to ValueError with context.
        try:
            return ast.literal_eval(value)
        except (SyntaxError, ValueError, TypeError) as err:
            raise ValueError(f"Invalid beam width array: {value}") from err

    if isinstance(bwa, str):
        bwa = parse_literal(bwa)  # "bwa" is short for "beam_width_array"
        if isinstance(bwa, str):
            bwa = parse_literal(bwa)  # Parse again for a doubly quoted string.

    def parse_one_bwa(row):
        # Validate one row, then pad it to `max_len` entries with its last value.
        if not isinstance(row, list):
            raise ValueError("Beam width array must be a list.")
        if not row:
            raise ValueError("Beam width array must not be empty.")
        if len(row) > max_len:
            raise ValueError(
                f"Length of beam width array must not be greater than "
                f"{max_len} now.")
        # bool is a subclass of int; reject it so `[True]` is not read as `[1]`.
        if not all(
                isinstance(beam, int) and not isinstance(beam, bool)
                for beam in row):
            raise ValueError("Numbers in beam width array must be integer.")
        padded = row + [row[-1]] * (max_len - len(row))  # New list; input untouched.
        return padded, max(row)

    if isinstance(bwa, list):  # Only one BWA
        return parse_one_bwa(bwa)

    if isinstance(bwa, tuple):  # BWA for respective requests
        if not bwa:
            raise ValueError("Beam width array must not be empty.")
        parsed = [parse_one_bwa(row) for row in bwa]
        rows = [padded for padded, _ in parsed]
        max_beam_width = max(width for _, width in parsed)
        return rows, max_beam_width

    raise ValueError(f"Invalid beam width array: {bwa}")
290
291
292def add_common_args(parser):

Callers 2

mainFunction · 0.90
mainFunction · 0.90

Calls 3

parse_one_bwaFunction · 0.85
maxFunction · 0.85
appendMethod · 0.45

Tested by

no test coverage detected