Repository: github.com/NVIDIA/TensorRT-LLM

Function wrap_torch_tensor

Defined in triton_kernels/tensor.py, lines 204–209.
Signature: wrap_torch_tensor(torch_tensor, dtype=None)

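def wrap_torch_tensor(torch_tensor, dtype=None):
    # Default: the logical dtype is the storage dtype, so the shape is unchanged.
    if dtype is None:
        dtype = torch_tensor.dtype
    shape = list(torch_tensor.shape)
    # Scale the first unit-stride (contiguous) dimension by the number of
    # logical elements per storage element, e.g. 8-bit storage holding
    # 4-bit values doubles that dimension.
    shape[torch_tensor.stride().index(1)] *= bitwidth(torch_tensor.dtype) // bitwidth(dtype)
    # Wrap the original torch storage with the logical dtype and shape.
    return Tensor(Storage(torch_tensor), dtype=dtype, shape=shape)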
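The shape adjustment above can be checked without a GPU or triton_kernels installed. The sketch below is a hypothetical pure-Python re-implementation of only the shape arithmetic: `logical_shape` and `contiguous_strides` are illustrative names, not part of triton_kernels, and bit widths are passed in directly instead of going through `bitwidth()`. The examples assume an 8-bit storage dtype carrying 4-bit logical values.

```python
from typing import List, Sequence


# Hypothetical helper mirroring the shape rule in wrap_torch_tensor.
# Plain lists stand in for torch tensors; bit widths are given directly
# rather than computed by triton_kernels' bitwidth().
def logical_shape(physical_shape: Sequence[int],
                  strides: Sequence[int],
                  storage_bits: int,
                  logical_bits: int) -> List[int]:
    shape = list(physical_shape)
    # Same rule as stride().index(1): the first unit-stride dimension
    # is treated as the packed (contiguous) dimension.
    packed_dim = list(strides).index(1)
    # Each storage element holds storage_bits // logical_bits logical values.
    shape[packed_dim] *= storage_bits // logical_bits
    return shape


def contiguous_strides(shape: Sequence[int]) -> List[int]:
    """Row-major strides in elements, as torch.Tensor.stride() reports them."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides


if __name__ == "__main__":
    # Row-major (128, 32) buffer of 8-bit storage holding 4-bit values:
    # the last dimension is contiguous, so it doubles.
    print(logical_shape([128, 32], contiguous_strides([128, 32]), 8, 4))  # [128, 64]
    # Transposed view of that buffer: shape (32, 128), strides (1, 32).
    # Dimension 0 is now the contiguous one, so it doubles instead.
    print(logical_shape([32, 128], [1, 32], 8, 4))  # [64, 128]
    # dtype=None case (logical width == storage width): shape unchanged.
    print(logical_shape([128, 32], contiguous_strides([128, 32]), 8, 8))  # [128, 32]
```

Because the packed dimension is located by stride rather than by position, a transposed view is widened along dimension 0 rather than along the last dimension.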

Callers (3)

swizzle_weight_and_scale (function) · 0.90
_swizzle_mxfp4 (function) · 0.90
matmul_ogs (function) · 0.85

Calls (4)

bitwidth (function) · 0.85
Storage (class) · 0.85
stride (method) · 0.80
Tensor (class) · 0.70

Tested by

no test coverage detected