github.com/NVIDIA/TensorRT-LLM / cast

Function cast

tensorrt_llm/functional.py:876–919

Add a cast operation. For an input tensor of type INT8, this function sets the dynamic range of the input to [-127, 127] for automatic dequantization. For a cast into INT8, it sets the dynamic range of the output to [-127, 127] for automatic quantization. Both ranges are set only when the network is not strongly typed.
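Why [-127, 127]? A minimal pure-Python sketch of symmetric per-tensor INT8 quantization makes the effect concrete. It assumes that TensorRT derives the scale from a symmetric dynamic range as amax / 127, saturates to [-127, 127], and rounds half to even (Python's `round`); the function names are hypothetical and none of this is TensorRT code. Under those assumptions, the range that `cast` sets gives a scale of exactly 1.0, so INT8 values would map one-to-one onto real values.

```python
# Sketch only: symmetric per-tensor INT8 quantization, ASSUMING the scale is
# amax / 127 for a dynamic range [-amax, amax]. Not TensorRT's implementation.

def scale_from_range(amax: float) -> float:
    """Scale implied by a symmetric dynamic range [-amax, amax]."""
    return amax / 127.0

def quantize(x: float, amax: float) -> int:
    """Round x / scale (half to even, an assumption) and saturate to [-127, 127]."""
    q = round(x / scale_from_range(amax))
    return max(-127, min(127, q))

def dequantize(q: int, amax: float) -> float:
    """Map an INT8 value back to real units."""
    return q * scale_from_range(amax)

# With the [-127, 127] range used by cast(), the scale is exactly 1.0.
AMAX = 127.0
print(scale_from_range(AMAX))  # 1.0
print(quantize(3.2, AMAX))     # 3
print(quantize(500.0, AMAX))   # 127 (saturated)
print(dequantize(-42, AMAX))   # -42.0
```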

(input: Tensor, dtype: Union[str, trt.DataType])
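The signature accepts `dtype` either as a type name or as a `trt.DataType`. A minimal sketch of that normalization step, mirroring the `isinstance` branches in the source, might look like this. `DataType` and `_STR_TO_DTYPE` are hypothetical stand-ins for `trt.DataType` and `_str_to_trt_dtype_dict` so that the sketch runs without TensorRT installed, and only a few types are listed.

```python
from enum import Enum
from typing import Union

class DataType(Enum):
    """Hypothetical stand-in for trt.DataType (a few members only)."""
    FLOAT = "float32"
    HALF = "float16"
    INT8 = "int8"
    INT32 = "int32"

# Hypothetical name-to-type table standing in for _str_to_trt_dtype_dict.
_STR_TO_DTYPE = {
    "float32": DataType.FLOAT,
    "float16": DataType.HALF,
    "int8": DataType.INT8,
    "int32": DataType.INT32,
}

def normalize_dtype(dtype: Union[str, DataType]) -> DataType:
    """Accept a type name or a DataType; reject anything else."""
    if isinstance(dtype, str):
        # Unknown names raise KeyError here; the real str_dtype_to_trt
        # may report them differently.
        return _STR_TO_DTYPE[dtype]
    if isinstance(dtype, DataType):
        return dtype
    raise TypeError("%s is not supported" % type(dtype))

print(normalize_dtype("float16").name)        # HALF
print(normalize_dtype(DataType.INT8).name)    # INT8
```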


```python
def cast(input: Tensor, dtype: Union[str, trt.DataType]):
    '''
    Add a cast operation.

    For an input tensor of type INT8, this function sets the dynamic range of
    the input to [-127, 127] for automatic dequantization. For a cast into
    INT8, that function sets the dynamic range of the output to [-127, 127] for
    automatic quantization.

    Parameters:
        input : Tensor
            The input tensor on which the cast is applied.

        dtype : str or trt.DataType
            The data type of the output tensor after the cast. When 'dtype' is
            provided as a string, it must be a name amongst the valid names.
            See _str_to_trt_dtype_dict in _utils.py for a list of supported
            types and type names.

    Returns:
        The tensor produced by the inserted layer.
    '''
    if isinstance(dtype, str):
        cvt_dtype = str_dtype_to_trt(dtype)
    elif isinstance(dtype, trt.DataType):
        cvt_dtype = dtype
    else:
        raise TypeError("%s is not supported" % type(dtype))

    if input.dtype == cvt_dtype:
        # If input type and cast dtype are the same, do nothing
        return input

    layer = default_trtnet().add_cast(input.trt_tensor, cvt_dtype)
    if not default_net().strongly_typed:
        layer.set_output_type(0, cvt_dtype)
    output = _create_tensor(layer.get_output(0), layer)
    if not default_net().strongly_typed:
        if input.dtype == str_dtype_to_trt('int8'):
            layer.get_input(0).set_dynamic_range(-127, 127)
        if cvt_dtype == str_dtype_to_trt('int8'):
            layer.get_output(0).set_dynamic_range(-127, 127)

    return output
```
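The rest of the body decides whether a layer is inserted at all and, on networks that are not strongly typed, whether the output type is pinned and which side of the cast gets an INT8 dynamic range. A pure-Python sketch of that decision logic is below; `CastPlan` and `plan_cast` are hypothetical names, dtypes are plain strings instead of `trt.DataType`, and no TensorRT layer is built.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class CastPlan:
    """Hypothetical record of what a cast would configure."""
    inserted: bool
    set_output_type: bool = False
    dynamic_ranges: List[str] = field(default_factory=list)

def plan_cast(in_dtype: str, out_dtype: str, strongly_typed: bool) -> CastPlan:
    # Same type: cast() returns its input unchanged and adds no layer.
    if in_dtype == out_dtype:
        return CastPlan(inserted=False)
    plan = CastPlan(inserted=True)
    if not strongly_typed:
        # Networks that are not strongly typed pin the output type explicitly...
        plan.set_output_type = True
        # ...and give INT8 tensors on either side a [-127, 127] dynamic range.
        if in_dtype == "int8":
            plan.dynamic_ranges.append("input")
        if out_dtype == "int8":
            plan.dynamic_ranges.append("output")
    return plan

print(plan_cast("int8", "float16", strongly_typed=False))
# CastPlan(inserted=True, set_output_type=True, dynamic_ranges=['input'])
```

Returning a plan rather than mutating a network keeps the branch structure testable in isolation; in the real function, these branches call `set_output_type` and `set_dynamic_range` on the TensorRT layer and its tensors.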

Callers (15)

forward (method) · 0.90
_validate_draft_tokens (function) · 0.90
_get_draft_token_indices (function) · 0.90
_get_draft_token_array (function) · 0.90
_get_mask (function) · 0.90
_top_1_logits (function) · 0.90
_prepare_drafter_input (function) · 0.90
forward (method) · 0.90
test_llm_api (method) · 0.85

Calls (6)

str_dtype_to_trt (function) · 0.85
default_trtnet (function) · 0.85
default_net (function) · 0.85
_create_tensor (function) · 0.85
get_output (method) · 0.45
get_input (method) · 0.45

Tested by (8)

test_llm_api (method) · 0.68
context_requests (method) · 0.68
generation_requests (method) · 0.68
_uut (function) · 0.68
_build_mock_requests (method) · 0.68
_uut_provider (method) · 0.68