Repository: github.com/NVIDIA/TensorRT-LLM

Function rand

tensorrt_llm/functional.py, lines 1401–1433

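This operation adds a fill layer that generates a random (uniform) tensor with the specified shape and data type. Parameters: shape (Tensor), the shape of the tensor to generate; low (float), the minimum value (inclusive) of the sampling range; high (float), the maximum value (inclusive) of the sampling range; dtype (Union[str, trt.DataType]), the data type of the output tensor. Returns the random tensor produced by the fill layer. The function is currently disabled: its body begins with assert False, so any call raises AssertionError until the TensorRT fill layer (RANDOM_UNIFORM) is fixed, as the source comment states. (Running Python with -O strips assert statements, which would bypass this guard.)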

rand(shape: Tensor,
     low: float = 0,
     high: float = 1,
     dtype: Union[str, trt.DataType] = 'float32') -> Tensor
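Because the TensorRT op itself is disabled, a host-side reference can help pin down the semantics the summary describes. The sketch below is hypothetical: rand_reference and flatten are illustrative names, shape is a plain Python sequence rather than a TensorRT-LLM Tensor, dtype is not modeled, and samples come from Python's random.uniform instead of the TensorRT fill layer. random.uniform returns values N with low <= N <= high, though whether the upper end point actually appears depends on floating-point rounding.

```python
import random
from typing import List, Optional, Sequence, Union

# Nested lists stand in for an n-dimensional tensor (hypothetical reference only).
Nested = Union[float, List["Nested"]]


def rand_reference(shape: Sequence[int],
                   low: float = 0.0,
                   high: float = 1.0,
                   rng: Optional[random.Random] = None) -> Nested:
    """Fill a nested list of the given shape with uniform samples in [low, high].

    Hypothetical host-side reference; it builds no TensorRT layer.
    """
    if low > high:
        raise ValueError("low must not exceed high")
    if rng is None:
        rng = random.Random()
    if len(shape) == 0:
        # An empty shape (0-d tensor) yields a single scalar sample.
        return rng.uniform(low, high)
    head, *tail = shape
    # Recurse over the leading dimension; an extent of 0 yields an empty list.
    return [rand_reference(tail, low, high, rng) for _ in range(head)]


def flatten(x: Nested) -> List[float]:
    """Collect every scalar in a nested list, in row-major order."""
    if isinstance(x, list):
        return [v for item in x for v in flatten(item)]
    return [x]


if __name__ == "__main__":
    sample = rand_reference([2, 3], low=-1.0, high=1.0, rng=random.Random(0))
    print(len(sample), len(sample[0]))  # 2 3
    print(all(-1.0 <= v <= 1.0 for v in flatten(sample)))  # True
```

Passing a seeded random.Random makes the reference reproducible, which is handy when comparing it against device output once the op is re-enabled.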


def rand(shape: Tensor,
         low: float = 0,
         high: float = 1,
         dtype: Union[str, trt.DataType] = 'float32') -> Tensor:
    '''
    This operation adds a fill layer that generates a random (uniform) tensor with the specified shape and data type.

    Parameters:
        shape: Tensor
            The shape of the tensor needed to be generated.
        low: float
            The minimum value (inclusive) of the range used for random.
        high: float
            The maximum value (inclusive) of the range used for random.
        dtype: Union[str, trt.DataType]
            The desired data type for the output tensor.
    Returns:
        The generated random tensor produced by the fill layer.
    '''
    # NOTE: DISABLED FOR NOW UNTIL THE FILL LAYER (RANDOM_UNIFORM) in TRT IS FIXED
    assert False, "The rand() op is temporarily disabled."
    low = constant(fp32_array(low))
    high = constant(fp32_array(high))
    trt_dtype = dtype if isinstance(dtype,
                                    trt.DataType) else str_dtype_to_trt(dtype)

    layer = default_trtnet().add_fill([0], trt.FillOperation.RANDOM_UNIFORM,
                                      trt_dtype)

    layer.set_input(0, shape.trt_tensor)
    layer.set_input(1, low.trt_tensor)
    layer.set_input(2, high.trt_tensor)
    return _create_tensor(layer.get_output(0), layer)
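The line trt_dtype = dtype if isinstance(dtype, trt.DataType) else str_dtype_to_trt(dtype) lets callers pass either a string such as 'float32' or an enum member. The stdlib sketch below mirrors that pattern under stated assumptions: DataType, _STR_TO_DTYPE, str_dtype_to_enum and normalize_dtype are hypothetical stand-ins, not the TensorRT or TensorRT-LLM APIs, and the string table is illustrative rather than the library's actual mapping.

```python
from enum import Enum
from typing import Union


class DataType(Enum):
    """Hypothetical stand-in for trt.DataType; not the real TensorRT enum."""
    FLOAT = "float32"
    HALF = "float16"
    INT32 = "int32"


# Hypothetical lookup table playing the role of str_dtype_to_trt().
_STR_TO_DTYPE = {member.value: member for member in DataType}


def str_dtype_to_enum(name: str) -> DataType:
    """Convert a dtype string such as 'float32' into the stand-in enum."""
    try:
        return _STR_TO_DTYPE[name]
    except KeyError:
        raise ValueError(f"unsupported dtype string: {name!r}") from None


def normalize_dtype(dtype: Union[str, DataType]) -> DataType:
    """Pass enum members through unchanged and convert strings, as rand() does."""
    return dtype if isinstance(dtype, DataType) else str_dtype_to_enum(dtype)


if __name__ == "__main__":
    print(normalize_dtype("float32"))      # DataType.FLOAT
    print(normalize_dtype(DataType.HALF))  # DataType.HALF
```

Accepting both forms keeps call sites short (dtype='float16') while letting code that already holds an enum member pass it through without a string round-trip.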

Callers (2)

_validate_draft_tokens · function · 0.90
categorical_sample · function · 0.85

Calls (5)

constant · function · 0.85
str_dtype_to_trt · function · 0.85
default_trtnet · function · 0.85
_create_tensor · function · 0.85
get_output · method · 0.45

Tested by

no test coverage detected