MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / Tensor

Class Tensor

triton_kernels/tensor.py:101–166  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

99
100@dataclass
101class Tensor:
102 storage: Storage | torch.Tensor
103 dtype: IntegerType | FloatType | torch.dtype = None
104 shape: list[int] | None = None
105 shape_max: list[int] | None = None
106
107 def __post_init__(self):
108 # set storage
109 if isinstance(self.storage, torch.Tensor):
110 self.storage = Storage(self.storage)
111 # initialize dtype
112 if self.dtype is None:
113 self.dtype = self.storage.data.dtype
114 if bitwidth(self.dtype) < 8 and self.shape is None:
115 raise ValueError("shape must be provided for sub-byte types")
116 # initialize shape
117 if self.shape is None:
118 self.shape = list(self.storage.data.shape)
119 # validate shape: all elements must be `int` or numel-1 `torch.Tensor`
120 is_int = lambda s: isinstance(s, int)
121 is_item = lambda s: hasattr(s, "numel") and s.numel() == 1
122 assert all(map(lambda s: is_int(s) or is_item(s), self.shape))
123 # initialize shape_max
124 if self.shape_max is None:
125 self.shape_max = [None] * len(self.shape)
126 for i, (s, smax) in enumerate(zip(self.shape, self.shape_max)):
127 if smax is not None and not is_int(smax):
128 raise ValueError(f"shape_max[{i}] must be `int` or `None`; got {type(smax)}")
129 if smax is None:
130 self.shape_max[i] = s
131 # validate shape_max: all elements must be `int`
132 assert all(map(is_int, self.shape_max))
133
134 # torch compatibility layer
135 @property
136 def ndim(self):
137 return len(self.shape)
138
139 @property
140 def device(self):
141 return self.storage.device
142
143 def stride(self, i=None):
144 return self.storage.data.stride() if i is None else self.storage.data.stride(i)
145
146 def data_ptr(self):
147 return self.storage.data.data_ptr()
148
149 def numel(self):
150 return self.storage.data.numel()
151
152 def element_size(self):
153 return bitwidth(self.dtype) // 8
154
155 @property
156 def data(self):
157 t = self.storage
158 return t.data if isinstance(t, Storage) else t

Callers 4

topk_forward · Function · 0.90
matmul_ogs · Function · 0.70
wrap_torch_tensor · Function · 0.70
convert_layout · Function · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected