A wrapper that wraps a raw data pointer into a tensor-like object. It is compatible with OpenAI Triton kernels and can be converted to a `torch.Tensor` without copying (zero-copy).
| 982 | |
| 983 | |
| 984 | class TensorWrapper: |
| 985 | """ |
| 986 | A wrapper that wraps a raw data pointer into a tensor-like object. It is compatible with OpenAI Triton kernels and can be converted to a `torch.Tensor` without copying (zero-copy). |
| 987 | """ |
| 988 | |
| 989 | def __init__( |
| 990 | self, |
| 991 | data_ptr: int, |
| 992 | dtype: Union[torch.dtype, str, np.dtype, trt.DataType], |
| 993 | shape: Sequence[int], |
| 994 | strides: Optional[Sequence[int]] = None, |
| 995 | ): |
| 996 | assert isinstance(data_ptr, int) |
| 997 | self._data_ptr = data_ptr |
| 998 | self.dtype = dtype |
| 999 | self.shape = shape |
| 1000 | self.strides = strides |
| 1001 | |
| 1002 | def data_ptr(self): |
| 1003 | return self._data_ptr |
| 1004 | |
| 1005 | @property |
| 1006 | def dtype(self): |
| 1007 | return self._dtype |
| 1008 | |
| 1009 | @property |
| 1010 | def shape(self): |
| 1011 | return getattr(self, "_shape", None) |
| 1012 | |
| 1013 | @dtype.setter |
| 1014 | def dtype(self, dtype: Union[torch.dtype, str, np.dtype, trt.DataType]): |
| 1015 | if isinstance(dtype, torch.dtype): |
| 1016 | self._dtype = dtype |
| 1017 | elif isinstance(dtype, str): |
| 1018 | self._dtype = str_dtype_to_torch(dtype) |
| 1019 | elif isinstance(dtype, np.dtype): |
| 1020 | self._dtype = np_dtype_to_torch(dtype) |
| 1021 | elif isinstance(dtype, trt.DataType): |
| 1022 | self._dtype = trt_dtype_to_torch(dtype) |
| 1023 | else: |
| 1024 | raise TypeError(f"Unsupported dtype: {dtype}") |
| 1025 | |
| 1026 | @shape.setter |
| 1027 | def shape(self, shape: Sequence[int]): |
| 1028 | self._shape = tuple(int(i) for i in shape) |
| 1029 | |
| 1030 | def numel(self): |
| 1031 | return volume(self.shape) |
| 1032 | |
| 1033 | @property |
| 1034 | def __cuda_array_interface__(self): |
| 1035 | return { |
| 1036 | "shape": |
| 1037 | self.shape, |
| 1038 | "typestr": |
| 1039 | torch_dtype_to_np_typestr(self.dtype), |
| 1040 | "data": (self.data_ptr() if self.numel() > 0 else 0, False), |
| 1041 | "strides": [ |
no outgoing calls
no test coverage detected