github.com/NVIDIA/TensorRT-LLM

Function expand_dims_like

tensorrt_llm/functional.py, lines 2010–2051

Add an operation to expand the first tensor to the same rank as the second tensor. The function also accepts an integer or a float as the first argument, in which case it first creates a constant tensor from it. In both cases, only the ranks are compared: if the first tensor already has at least the rank of the second, it is returned unchanged; otherwise, size-1 dimensions are prepended until the ranks match.

(left: Union[Tensor, int, float], right: Tensor) -> Tensor
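The rank rule can be illustrated on plain shape tuples. This is a hedged, pure-Python sketch, not the TensorRT-LLM implementation: `expand_shape_like` is a hypothetical helper that only computes the resulting shape. It assumes, as the docstring's [3, 4] to [1, 3, 4] example indicates, that the source's `expand_dims(left, list(range(right_ndim - left_ndim)))` call inserts the new size-1 axes at the front.

```python
from typing import Sequence, Tuple


def expand_shape_like(left_shape: Sequence[int], right_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical shape-only model of expand_dims_like.

    Prepends size-1 dimensions to `left_shape` until it has the same rank as
    `right_shape`. Only the ranks are compared; the sizes are never checked.
    """
    left_ndim = len(left_shape)
    right_ndim = len(right_shape)
    if right_ndim > left_ndim:
        # Same effect as inserting new axes at positions 0 .. right_ndim - left_ndim - 1.
        return (1,) * (right_ndim - left_ndim) + tuple(left_shape)
    # Same rank, or `left` already has the higher rank: returned unchanged.
    return tuple(left_shape)


# The docstring's example: [3, 4] against [4, 3, 2] becomes [1, 3, 4],
# even though the sizes of the two shapes do not line up.
print(expand_shape_like([3, 4], [4, 3, 2]))  # (1, 3, 4)
print(expand_shape_like([2, 3, 4], [5]))     # (2, 3, 4)
```

The second call corresponds to the source's final `return left`: when the first tensor already has the higher rank, nothing is changed.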


def expand_dims_like(left: Union[Tensor, int, float], right: Tensor) -> Tensor:
    '''
    Add an operation to expand the first tensor to the same rank as the second
    tensor.

    This function takes a first tensor. It also accepts an integer or a float,
    in which case it first creates a constant tensor from it. In both cases,
    the rank of that first tensor is compared to the rank of the second
    tensor. If the first tensor's rank is greater than or equal to that of
    the second, the first tensor is returned unchanged. Otherwise, it is
    expanded on the left, with dimensions of size 1, to match the rank of the
    second tensor.

    Note that the shapes do not have to match; only the rank is considered
    in this function.

    For example, for a pair of tensors of shapes [3, 4] and [4, 3, 2], the
    first tensor will be expanded to a tensor of rank 3 and shape [1, 3, 4].

    Parameters:
        left : Union[Tensor, int, float]
            The first tensor to expand. When a scalar value is provided, this
            function first creates a one-element, rank-1 constant tensor from
            it before expanding it (if needed). A float becomes an fp16
            constant when right is half precision and an fp32 constant
            otherwise.

        right : Tensor
            The reference tensor whose rank is matched.

    Returns:
        The tensor produced by the shuffle layer when an expansion is needed,
        or left itself (as a constant tensor if a scalar was passed)
        otherwise.
    '''
    if isinstance(left, int):
        left = constant(dims_array([left]))
    elif isinstance(left, float):
        if isinstance(right, Tensor) and right.dtype == trt.DataType.HALF:
            left = constant(fp16_array([left]))
        else:
            left = constant(fp32_array([left]))
    left_ndim = left.ndim()
    right_ndim = right.ndim()
    if right_ndim > left_ndim:
        new_ndim = list(range(right_ndim - left_ndim))
        return expand_dims(left, new_ndim)
    return left
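The scalar branches at the top of the function can be modelled the same way. In this sketch a tensor is reduced to a hypothetical `FakeTensor` holding only a shape and a dtype label; `FakeTensor`, `as_tensor_like`, and the dtype strings are assumptions for illustration, not TensorRT-LLM names. What the model takes from the source: an `int` becomes a one-element, rank-1 constant (via `dims_array`), and a `float` becomes a one-element fp16 constant when `right` is half precision, fp32 otherwise. The exact integer dtype that `dims_array` produces is not shown on this page, so the model labels it only as "int".

```python
from dataclasses import dataclass
from typing import Tuple, Union


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a Tensor: only a shape and a dtype label."""
    shape: Tuple[int, ...]
    dtype: str  # "float16", "float32", "int", ... (labels assumed for this sketch)


def as_tensor_like(left: Union[FakeTensor, int, float], right: FakeTensor) -> FakeTensor:
    """Shape/dtype model of expand_dims_like, including its scalar branches."""
    if isinstance(left, int):
        # Mirrors constant(dims_array([left])): a rank-1 tensor of shape (1,).
        left = FakeTensor((1,), "int")
    elif isinstance(left, float):
        # Mirrors the HALF check: fp16 if `right` is half precision, else fp32.
        dtype = "float16" if right.dtype == "float16" else "float32"
        left = FakeTensor((1,), dtype)
    missing = len(right.shape) - len(left.shape)
    if missing > 0:
        # Prepend size-1 dimensions, as expand_dims(left, range(missing)) does.
        return FakeTensor((1,) * missing + left.shape, left.dtype)
    return left


half = FakeTensor((2, 8, 64), "float16")
print(as_tensor_like(0.5, half))  # FakeTensor(shape=(1, 1, 1), dtype='float16')
print(as_tensor_like(3, half))    # FakeTensor(shape=(1, 1, 1), dtype='int')
```

One consequence of the float branch as written: only half precision is special-cased, so for a `right` of any other dtype (bf16 included) the scalar constant is created as fp32.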
2052
2053

Callers (2)

broadcast_helper (function) · 0.85
where (function) · 0.85

Calls (4)

constant (function) · 0.85
dims_array (function) · 0.85
expand_dims (function) · 0.85
ndim (method) · 0.45

Tested by

no test coverage detected