(input: Union[Tensor, int, float, bool],
dtype: Union[trt.DataType, str] = None,
to_array=True)
| 2852 | |
| 2853 | |
def constant_to_tensor_(input: Union[Tensor, int, float, bool],
                        dtype: Union[trt.DataType, str] = None,
                        to_array=True) -> Tensor:
    """Return ``input`` as a Tensor, wrapping Python scalars in a constant.

    A value that is already a ``Tensor`` is handed back untouched and
    ``dtype`` is ignored for it. Any other value is converted with the
    array helper registered for the resolved dtype and then passed to
    ``constant``.

    Args:
        input: A ``Tensor``, or a plain value to turn into a constant.
        dtype: Target type, either a ``trt.DataType`` or a string name that
            ``str_dtype_to_trt`` understands. When ``None`` the type is
            deduced from the value: ``bool`` -> ``trt.bool``, ``int`` ->
            ``trt.int32``, anything else -> ``trt.float32``.
        to_array: When truthy, the value is wrapped in a one-element list
            before being given to the array helper; otherwise it is passed
            through as-is.

    Returns:
        ``input`` itself when it is a ``Tensor``, otherwise the result of
        ``constant(...)`` on the freshly built array.

    Raises:
        AssertionError: If the resolved dtype is not one of int64, int32,
            float32, float16, bfloat16 or bool.
    """
    # Tensors need no conversion at all.
    if isinstance(input, Tensor):
        return input

    if dtype is None:
        # Deduce the type from the value itself.
        # NOTE: bool is a subclass of int, so it has to be tested first.
        if isinstance(input, bool):
            dtype = trt.bool
        else:
            dtype = trt.int32 if isinstance(input, int) else trt.float32

    if isinstance(dtype, str):
        dtype = str_dtype_to_trt(dtype)

    # One array-building helper per supported TensorRT dtype.
    builders = {
        trt.int64: int64_array,
        trt.int32: int32_array,
        trt.float32: fp32_array,
        trt.float16: fp16_array,
        trt.bfloat16: bf16_array,
        trt.bool: bool_array,
    }
    assert dtype in builders

    payload = [input] if to_array else input
    return constant(builders[dtype](payload))
| 2882 | |
| 2883 | |
| 2884 | def constants_to_tensors_( |
no test coverage detected