Add an operation to create a shape tensor. The shape tensor can either be the shape of the input tensor when the parameter dim is None or a scalar (tensor of rank 0) that corresponds to the size of the dim-th dimension. Parameters: input : Tensor The input tensor from which we want to extract the shape or the size in one dimension.
(input: Tensor,
dim: Optional[int] = None,
cast_to_dtype: Optional[Union[str, trt.DataType]] = None,
clip_before_cast: Sequence[int] = None)
# If dim is None, return a 1-D TensorRT LLM tensor holding the input's shape.
# If dim is not None, return a 0-D TensorRT LLM tensor holding that dimension's size.
def shape(input: Tensor,
          dim: Optional[int] = None,
          cast_to_dtype: Optional[Union[str, trt.DataType]] = None,
          clip_before_cast: Optional[Sequence[int]] = None) -> Tensor:
    '''
    Add an operation to create a shape tensor.

    The shape tensor is either the full shape of the input tensor (when the
    parameter dim is None) or a scalar (tensor of rank 0) that holds the
    size of the dim-th dimension.

    Parameters:
        input : Tensor
            The input tensor from which we want to extract the shape or the
            size in one dimension.

        dim : Optional[int]
            The dimension from which to extract the size. If it is None, the
            entire shape of the input tensor is returned.

        cast_to_dtype : Optional[Union[str, trt.DataType]]
            If not None, the shape tensor is cast to this data type. The cast
            happens before 'dim' is applied. If None, no cast is added.

        clip_before_cast : Optional[Sequence[int]]
            A (lower, upper) pair used to clip the shape values before the
            cast. It is only used when 'cast_to_dtype' is 'int32' or
            trt.int32. In every other case it is ignored and not validated.

    Returns:
        A tensor that contains the shape of the input tensor (if 'dim' is
        None) or the size in the dimension 'dim' of the input tensor. If
        'dim' is None, that tensor is 1-D with one element per dimension of
        the input tensor; otherwise its rank is 0.

    Raises:
        AssertionError
            If clipping is requested on the int32 path and 'clip_before_cast'
            does not contain exactly two values.
    '''
    layer = default_trtnet().add_shape(input.trt_tensor)
    res = _create_tensor(layer.get_output(0), layer)

    if cast_to_dtype is not None:
        casting_to_int32 = (cast_to_dtype == 'int32'
                            or cast_to_dtype == trt.int32)
        # Clipping is only applied on the int32 path. NOTE(review): presumably
        # this keeps the shape values within int32 range before narrowing;
        # confirm.
        if clip_before_cast is not None and casting_to_int32:
            assert len(
                clip_before_cast
            ) == 2, f"This parameter only expects a tuple of 2 integers (lower, upper) but got {clip_before_cast}"
            res = int_clip(res, clip_before_cast[0], clip_before_cast[1])
        res = cast(res, cast_to_dtype)

    if dim is None:
        return res

    # Pick the dim-th entry of the 1-D shape tensor, then reshape it to rank 0.
    # NOTE(review): a negative 'dim' is forwarded unchanged to gather; confirm
    # that gather accepts negative indices here.
    return gather(res, dim=0, indices=dim).view([])
| 2097 | |
| 2098 | |
| 2099 | def gather(input: Tensor, dim: int, indices: Union[Tensor, int]) -> Tensor: |