Class TensorWrapper

tensorrt_llm/_utils.py:984–1051 · github.com/NVIDIA/TensorRT-LLM

Wraps a raw data pointer in a tensor-like object. The wrapper is meant to be compatible with OpenAI Triton kernels and can be converted to a `torch.Tensor` without copying (zero-copy).
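TensorWrapper stores only an integer address plus dtype and shape metadata. It neither allocates nor owns the memory, so wrapping costs nothing, but keeping the underlying buffer alive is the caller's job. The same idea can be shown on the CPU with only the standard library: `ctypes` can build an array type directly over an existing address without copying. This is an analogy, not TensorRT-LLM code, and `wrap_float32` is a hypothetical helper.

```python
import array
import ctypes


def wrap_float32(address: int, count: int):
    """Hypothetical helper: view `count` float32 values at `address` in place.

    No memory is allocated or copied; the caller must keep the owner alive.
    """
    return (ctypes.c_float * count).from_address(address)


# An owning buffer; its address plays the role of TensorWrapper's data_ptr.
# It must not be resized while the view exists, since resizing may reallocate.
owner = array.array("f", [1.0, 2.0, 3.0, 4.0])
address, count = owner.buffer_info()  # (start address, number of items)

view = wrap_float32(address, count)
view[0] = 10.0           # write through the view ...
assert owner[0] == 10.0  # ... and the owning buffer sees the change
owner[3] = -1.0
assert view[3] == -1.0   # and vice versa: both names share one buffer
```

If `owner` were freed, `view` would dangle; the same caveat applies to a TensorWrapper built from a device pointer.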


```python
class TensorWrapper:
    """
    A wrapper wraps raw data pointer to a tensor-like object. Could be compatible with openai triton kernel and be converted to `torch.Tensor` with zero-copy overhead.
    """

    def __init__(
        self,
        data_ptr: int,
        dtype: Union[torch.dtype, str, np.dtype, trt.DataType],
        shape: Sequence[int],
        strides: Optional[Sequence[int]] = None,
    ):
        assert isinstance(data_ptr, int)
        self._data_ptr = data_ptr
        self.dtype = dtype
        self.shape = shape
        self.strides = strides

    def data_ptr(self):
        return self._data_ptr

    @property
    def dtype(self):
        return self._dtype

    @property
    def shape(self):
        return getattr(self, "_shape", None)

    @dtype.setter
    def dtype(self, dtype: Union[torch.dtype, str, np.dtype, trt.DataType]):
        if isinstance(dtype, torch.dtype):
            self._dtype = dtype
        elif isinstance(dtype, str):
            self._dtype = str_dtype_to_torch(dtype)
        elif isinstance(dtype, np.dtype):
            self._dtype = np_dtype_to_torch(dtype)
        elif isinstance(dtype, trt.DataType):
            self._dtype = trt_dtype_to_torch(dtype)
        else:
            raise TypeError(f"Unsupported dtype: {dtype}")

    @shape.setter
    def shape(self, shape: Sequence[int]):
        self._shape = tuple(int(i) for i in shape)

    def numel(self):
        return volume(self.shape)

    @property
    def __cuda_array_interface__(self):
        return {
            "shape":
            self.shape,
            "typestr":
            torch_dtype_to_np_typestr(self.dtype),
            "data": (self.data_ptr() if self.numel() > 0 else 0, False),
            "strides": [
                # ... listing truncated here; the class continues to line 1051
```
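The listing stops partway through `__cuda_array_interface__`, so the sketch below does not reproduce the rest of it. It is a minimal, CPU-testable stand-in that follows the same pattern as the class: normalize dtype and shape once when the object is built, then produce a CUDA Array Interface (version 3) dict on demand, using pointer 0 for zero-size arrays and expressing strides in bytes as the spec requires. It assumes a small hypothetical dtype table in place of the library's converters (`str_dtype_to_torch`, `torch_dtype_to_np_typestr`) and assumes strides are given in elements, like `torch.Tensor.stride()`. `PointerTensor` and `_TYPESTR` are hypothetical names.

```python
import math
from typing import Optional, Sequence

# Hypothetical lookup table standing in for TensorRT-LLM's dtype helpers:
# dtype name -> (CUDA Array Interface typestr, element size in bytes).
_TYPESTR = {
    "float32": ("<f4", 4),
    "float16": ("<f2", 2),
    "int32": ("<i4", 4),
    "int64": ("<i8", 8),
    "bool": ("|b1", 1),
}


class PointerTensor:
    """Hypothetical stand-in: records a raw device pointer plus metadata."""

    def __init__(self, data_ptr: int, dtype: str, shape: Sequence[int],
                 strides: Optional[Sequence[int]] = None):
        if not isinstance(data_ptr, int):
            raise TypeError("data_ptr must be an int")
        if dtype not in _TYPESTR:
            raise TypeError(f"Unsupported dtype: {dtype}")
        self._data_ptr = data_ptr
        self.dtype = dtype
        # Normalize any integer-like sequence to a tuple of Python ints.
        self.shape = tuple(int(d) for d in shape)
        # Assumption: strides are given in elements (like torch.Tensor.stride()).
        self.strides = None if strides is None else tuple(int(s) for s in strides)

    def numel(self) -> int:
        # Product of the dimensions; a 0-d shape () has one element.
        return math.prod(self.shape)

    @property
    def __cuda_array_interface__(self) -> dict:
        typestr, itemsize = _TYPESTR[self.dtype]
        return {
            "shape": self.shape,
            "typestr": typestr,
            # The spec asks for pointer 0 on zero-size arrays; False = writable.
            "data": (self._data_ptr if self.numel() > 0 else 0, False),
            # The spec expresses strides in bytes; None means C-contiguous.
            "strides": (None if self.strides is None
                        else tuple(s * itemsize for s in self.strides)),
            "version": 3,
        }
```

With a real device address (for example the `data_ptr()` of a CUDA `torch.Tensor` that is kept alive), consumers such as `torch.as_tensor(obj, device="cuda")`, CuPy, or Numba can read this dict and wrap the memory without copying; that step needs a GPU and is not shown.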

Callers (2)

make_weak_ref (function) · 0.90
from_trt_desc (method) · 0.85

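Calls

No outgoing calls were detected by the indexer, although the listing itself calls the module helpers `str_dtype_to_torch`, `np_dtype_to_torch`, `trt_dtype_to_torch`, `volume`, and `torch_dtype_to_np_typestr`.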

Tested by

no test coverage detected