Repository: github.com/NVIDIA/TensorRT-LLM

Function constant_to_tensor_

tensorrt_llm/functional.py, lines 2854–2881
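constant_to_tensor_(input: Union[Tensor, int, float, bool], dtype: Union[trt.DataType, str] = None, to_array=True) -> Tensor

Converts a Python scalar (bool, int, or float) into a constant Tensor, deducing the dtype from the value when none is given. A Tensor input is returned unchanged.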
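The `to_array` flag only decides whether the scalar is wrapped as `[input]` before it reaches the array helper (the `return constant(...)` line of the source). Assuming helpers such as `int32_array` build NumPy-style arrays from their argument, the default yields a one-element 1-D constant, while `to_array=False` yields a rank-0 one. This sketch uses plain Python lists and the hypothetical helpers `wrap_for_constant` and `shape_of` to show the resulting shapes; it does not call TensorRT-LLM.

```python
from typing import Any, Tuple


def wrap_for_constant(value: Any, to_array: bool = True) -> Any:
    """Hypothetical mirror of the `[input] if to_array else input` expression."""
    return [value] if to_array else value


def shape_of(value: Any) -> Tuple[int, ...]:
    """Shape of a nested Python list; a bare scalar has shape ()."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None  # descend into the first element
    return tuple(shape)


print(shape_of(wrap_for_constant(3)))                  # (1,)
print(shape_of(wrap_for_constant(3, to_array=False)))  # ()
```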


```python
def constant_to_tensor_(input: Union[Tensor, int, float, bool],
                        dtype: Union[trt.DataType, str] = None,
                        to_array=True) -> Tensor:
    if dtype is None:
        # deduce the type from the given value
        # NOTE: bool is a subtype of int, so bool needs to be checked first
        if isinstance(input, bool):
            dtype = trt.bool
        elif isinstance(input, int):
            dtype = trt.int32
        else:
            dtype = trt.float32

    if not isinstance(input, Tensor):
        if isinstance(dtype, str):
            dtype = str_dtype_to_trt(dtype)
        array_fn_dict = {
            trt.int64: int64_array,
            trt.int32: int32_array,
            trt.float32: fp32_array,
            trt.float16: fp16_array,
            trt.bfloat16: bf16_array,
            trt.bool: bool_array,
        }
        assert dtype in array_fn_dict
        return constant(array_fn_dict[dtype]([input] if to_array else input))

    return input
```
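The dtype deduction depends on branch order: in Python `bool` is a subclass of `int`, so `isinstance(True, int)` is `True`, and testing `int` first would silently turn every bool into `trt.int32`. A minimal pure-Python sketch of both orderings follows; dtype names are plain strings standing in for the `trt` members, and `deduce_dtype_name` / `deduce_dtype_name_wrong_order` are hypothetical helpers, not TensorRT-LLM functions.

```python
def deduce_dtype_name(value) -> str:
    """Default dtype for a scalar, checking bool before int as the source does."""
    if isinstance(value, bool):
        return "bool"
    elif isinstance(value, int):
        return "int32"
    else:
        return "float32"


def deduce_dtype_name_wrong_order(value) -> str:
    """Same checks with int first: bools are misclassified as int32."""
    if isinstance(value, int):
        return "int32"
    elif isinstance(value, bool):  # never reached: every bool already matched int
        return "bool"
    else:
        return "float32"


assert isinstance(True, int)  # bool is a subclass of int
for v in (True, 7, 2.5):
    print(v, deduce_dtype_name(v), deduce_dtype_name_wrong_order(v))
# True bool int32
# 7 int32 int32
# 2.5 float32 float32
```

Note also that any Python `int` defaults to `trt.int32`, not `trt.int64`; values outside the 32-bit range need an explicit `dtype` such as `trt.int64`, which the dispatch table does support.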

Callers (15; 10 shown)

_get_draft_token_array (function) · 0.90
_get_mask (function) · 0.90
warp_logits (function) · 0.90
_beam_search_candidates (function) · 0.90
_top_1_logits (function) · 0.90
encode_text (method) · 0.90
forward (method) · 0.90
pad (function) · 0.85
arange (function) · 0.85
cumsum (function) · 0.85

Calls (2)

str_dtype_to_trt (function) · 0.85
constant (function) · 0.85
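The two callees split the conversion: once `dtype` is settled, a string is normalized through `str_dtype_to_trt`, the result is looked up in `array_fn_dict`, and the array it produces is handed to `constant`. Any dtype without an entry in that table (for example `trt.int8`, which the table does not list) fails the `assert`. The following is a hypothetical stand-in for that dispatch step only: strings replace `trt.DataType` members, and `_tagged_array`, `STR_TO_DTYPE`, `ARRAY_FN`, and `build_constant` are illustrative names, not TensorRT-LLM APIs.

```python
from typing import Callable, Dict, List


def _tagged_array(name: str) -> Callable[[List], Dict]:
    """Hypothetical array helper: records the dtype alongside the data."""
    return lambda data: {"dtype": name, "data": data}


# Hypothetical stand-in for str_dtype_to_trt. "int8" is recognized here but,
# as in array_fn_dict, has no array helper.
STR_TO_DTYPE: Dict[str, str] = {
    name: name
    for name in ("int64", "int32", "float32", "float16", "bfloat16", "bool", "int8")
}

# Mirrors the keys of array_fn_dict in constant_to_tensor_.
ARRAY_FN: Dict[str, Callable[[List], Dict]] = {
    name: _tagged_array(name)
    for name in ("int64", "int32", "float32", "float16", "bfloat16", "bool")
}


def build_constant(data: List, dtype: str) -> Dict:
    dtype = STR_TO_DTYPE[dtype]  # normalize the string name first
    assert dtype in ARRAY_FN, f"unsupported dtype: {dtype}"  # fail fast
    return ARRAY_FN[dtype](data)  # dispatch to the matching helper


print(build_constant([4], "int64"))  # {'dtype': 'int64', 'data': [4]}
try:
    build_constant([4], "int8")
except AssertionError as err:
    print(err)  # unsupported dtype: int8
```

Because the check is an `assert`, running Python with `-O` removes it; the dictionary lookup on the following line would then raise `KeyError` instead of `AssertionError`.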

Tested by

no test coverage detected