github.com/NVIDIA/TensorRT-LLM

Class SymTensor

tensorrt_llm/python_plugin.py:231–274

Represents a symbolic tensor. A SymTensor carries only dtype and shape information, which is what users need when writing their own shape/dtype inference functions.
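To illustrate the kind of function this enables, here is a minimal sketch of a user-written shape/dtype inference rule for a hypothetical 2-D matmul plugin. It deliberately does not import TensorRT-LLM: `FakeSymTensor` and `infer_matmul` are hypothetical names, dtypes are plain strings rather than `trt.DataType`, and shapes are concrete integer tuples. The real class also accepts symbolic `ShapeExpr` shapes, which this sketch does not model.

```python
from dataclasses import dataclass
from typing import Tuple


# Hypothetical stand-in for SymTensor: only dtype and shape, no tensor data.
@dataclass
class FakeSymTensor:
    dtype: str
    shape: Tuple[int, ...]


def infer_matmul(a: FakeSymTensor, b: FakeSymTensor) -> FakeSymTensor:
    """Hypothetical shape/dtype inference for a 2-D matmul plugin."""
    if a.dtype != b.dtype:
        raise TypeError(f"dtype mismatch: {a.dtype} vs {b.dtype}")
    if len(a.shape) != 2 or len(b.shape) != 2:
        raise ValueError("expected 2-D inputs")
    (m, k1), (k2, n) = a.shape, b.shape
    if k1 != k2:
        raise ValueError(f"inner dimensions differ: {k1} vs {k2}")
    # Output keeps the input dtype; shape is (rows of a, cols of b).
    return FakeSymTensor(dtype=a.dtype, shape=(m, n))


out = infer_matmul(FakeSymTensor("float16", (4, 8)), FakeSymTensor("float16", (8, 3)))
print(out)  # FakeSymTensor(dtype='float16', shape=(4, 3))
```

Only metadata flows through the function; no tensor memory is touched, which is the point of describing outputs symbolically.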

```python
# Excerpt: Union, Sequence, Type, torch, np, trt, ShapeExpr and the
# *_dtype_to_trt helpers are defined or imported at module level in python_plugin.py.
class SymTensor:
    """The class to represent symbolic tensors.

    Only contains dtype and shape information for users to write their own shape/dtype inference function.
    """

    def __init__(
        self,
        dtype: Union[torch.dtype, np.dtype, str, trt.DataType, Type[None]],
        shape: Union[ShapeExpr, Sequence[int]],
    ):
        # Assigning through the property setters validates and normalizes the inputs.
        self.dtype = dtype
        self.shape = shape

    @property
    def shape(self) -> Union[ShapeExpr, Sequence[int]]:
        return self._shape

    @shape.setter
    def shape(self, shape: Union[ShapeExpr, Sequence[int]]):
        # Accept a symbolic ShapeExpr, or a concrete list/tuple of ints.
        assert isinstance(shape, (ShapeExpr, list, tuple))
        if isinstance(shape, (list, tuple)):
            for i in shape:
                assert isinstance(i, int)
        self._shape = shape

    @property
    def dtype(self) -> Union[trt.DataType, Type[None]]:
        return self._dtype

    @dtype.setter
    def dtype(self, dtype: Union[torch.dtype, str, np.dtype, trt.DataType, Type[None]]):
        # Normalize every supported input to trt.DataType (or keep None).
        if isinstance(dtype, torch.dtype):
            self._dtype = torch_dtype_to_trt(dtype)
        elif isinstance(dtype, str):
            self._dtype = str_dtype_to_trt(dtype)
        elif isinstance(dtype, np.dtype):
            self._dtype = np_dtype_to_trt(dtype)
        elif isinstance(dtype, trt.DataType):
            self._dtype = dtype
        elif dtype is None:
            self._dtype = None
        else:
            raise TypeError(f"Unsupported dtype: {dtype}")
```
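Because `__init__` assigns through the property setters, validation and dtype normalization run both at construction and on any later assignment. The sketch below reproduces that pattern with the standard library only, so it runs without TensorRT. `MiniSymTensor`, `DataType` (a stand-in for `trt.DataType`) and `_STR_TO_DTYPE` (a stand-in for `str_dtype_to_trt`) are hypothetical; the torch/numpy branches and `ShapeExpr` shapes are omitted. In this sketch an unknown dtype string raises `KeyError`; the library helper's behavior in that case is not shown in the listing above.

```python
from enum import Enum
from typing import Optional, Sequence, Union


class DataType(Enum):
    """Hypothetical stand-in for trt.DataType."""
    FLOAT = "float32"
    HALF = "float16"
    INT32 = "int32"


# Hypothetical stand-in for str_dtype_to_trt (unknown names raise KeyError here).
_STR_TO_DTYPE = {d.value: d for d in DataType}


class MiniSymTensor:
    def __init__(self, dtype: Union[str, DataType, None], shape: Sequence[int]):
        # Both assignments go through the setters below.
        self.dtype = dtype
        self.shape = shape

    @property
    def dtype(self) -> Optional[DataType]:
        return self._dtype

    @dtype.setter
    def dtype(self, dtype: Union[str, DataType, None]):
        # Normalize strings to the enum; pass enum values and None through.
        if isinstance(dtype, str):
            self._dtype = _STR_TO_DTYPE[dtype]
        elif isinstance(dtype, DataType) or dtype is None:
            self._dtype = dtype
        else:
            raise TypeError(f"Unsupported dtype: {dtype}")

    @property
    def shape(self) -> Sequence[int]:
        return self._shape

    @shape.setter
    def shape(self, shape: Sequence[int]):
        # Concrete shapes only: a list/tuple whose entries are all ints.
        assert isinstance(shape, (list, tuple))
        for dim in shape:
            assert isinstance(dim, int)
        self._shape = shape


t = MiniSymTensor("float16", [2, 3])
print(t.dtype)  # DataType.HALF

try:
    t.shape = [2.5, 3]  # a float dimension is rejected by the shape setter
except AssertionError:
    print("rejected non-integer dimension")
```

The setter approach keeps one code path for validation, at the cost of relying on `assert`, which Python skips when run with `-O`.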

Callers (4)

shape_dtype_inference (method · 0.90)
get_output_data_types (method · 0.85)
get_output_shapes (method · 0.85)
__call__ (method · 0.85)

Calls

torch_dtype_to_trt, str_dtype_to_trt, np_dtype_to_trt (from the dtype setter)

Tested by

no test coverage detected