MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / interpolate

Function interpolate

tensorrt_llm/functional.py:966–1053  ·  view source on GitHub ↗
(input: Tensor,
                size: Union[int, List[int]] = None,
                scale_factor: Union[float, List[float]] = None,
                mode: str = 'nearest',
                align_corners: bool = False,
                recompute_scale_factor: bool = False,
                antialias: bool = False)

Source from the content-addressed store, hash-verified

964
965
966def interpolate(input: Tensor,
967 size: Union[int, List[int]] = None,
968 scale_factor: Union[float, List[float]] = None,
969 mode: str = 'nearest',
970 align_corners: bool = False,
971 recompute_scale_factor: bool = False,
972 antialias: bool = False) -> Tensor:
973 ##
974 ## TODO: Document that function!
975 ##
976
977 assert not input.is_dynamic()
978
979 input_ndim = input.ndim()
980
981 assert 2 < input_ndim < 6, "Only 3D, 4D and 5D input Tensors supported"
982 assert (size is not None) ^ (
983 scale_factor
984 is not None), "Only one of out_shape or scales should be defined"
985
986 assert mode in ('nearest', 'linear', 'bilinear', 'bicubic', 'trilinear',
987 'nearest-exact')
988
989 if mode == 'trilinear' and input_ndim != 5:
990 raise ValueError("trilinear only supports 5D tensor")
991
992 if mode == "bilinear" and input_ndim != 4:
993 raise ValueError("bilinear only supports 4D tensor")
994
995 if mode == "linear" and input_ndim != 3:
996 raise ValueError("linear only supports 3D tensor")
997
998 layer = default_trtnet().add_resize(input.trt_tensor)
999
1000 input_shape = input.size()
1001
1002 updated_shape = []
1003 if scale_factor:
1004 scale_len = 1 if isinstance(scale_factor,
1005 (float, int)) else len(scale_factor)
1006 if scale_len == 1 and isinstance(scale_factor, (float, int)):
1007 updated_scale = [scale_factor for _ in range(input_ndim - 2)]
1008
1009 else:
1010 updated_scale = scale_factor
1011 updated_shape = [
1012 int(math.floor(updated_scale[i - 2] *
1013 input_shape[i])) if i > 1 else input_shape[i]
1014 for i in range(input_ndim)
1015 ]
1016
1017 else:
1018 size_len = 1 if isinstance(size, int) else len(size)
1019 assert size_len == input_ndim - 2
1020 if size_len == 1 and isinstance(size, int):
1021 updated_size = [size for _ in range(input_ndim - 2)]
1022 else:
1023 updated_size = size

Callers 1

forward (Method) · 0.85

Calls 6

default_trtnet (Function) · 0.85
_create_tensor (Function) · 0.85
is_dynamic (Method) · 0.80
ndim (Method) · 0.45
size (Method) · 0.45
get_output (Method) · 0.45

Tested by

no test coverage detected