MCP · Copy · Index your code
hub / github.com/NVIDIA/TensorRT-LLM / Storage

Class Storage

triton_kernels/tensor.py:18–72  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

16
@dataclass
class Storage:
    """A torch tensor paired with the layout describing how it is stored.

    Also provides helpers to check whether the tensor can be addressed via a
    TMA descriptor and to build such descriptors.
    """

    # The wrapped tensor; must be a torch.Tensor (checked in __post_init__).
    data: torch.Tensor
    # Layout of `data`; when None, a StridedLayout over data.shape is used.
    layout: Layout = None

    def __post_init__(self):
        assert isinstance(self.data, torch.Tensor)
        if self.layout is None:
            self.layout = StridedLayout(self.data.shape)

    @property
    def device(self):
        """Device of the wrapped tensor."""
        return self.data.device

    def is_tma_compliant(self):
        """Return True if ``data`` satisfies the TMA constraints checked here.

        Requires compute capability >= 9.0 and a rank of 2, 3 or 5. Every
        stride except the first one equal to 1 must span a whole multiple of
        128 bits (stride * element bit width). uint8 data is counted as
        4 bits per element.
        """
        # TMAs didn't exist until Hopper
        if not cuda_capability_geq(9, 0):
            return False
        # TMAs only exist for 2D, 3D, 5D inputs
        if len(self.data.shape) not in [2, 3, 5]:
            return False
        # TMAs need at most one stride equal to 1
        # and all other strides divisible by 16 bytes (128 bits)
        strides = list(self.data.stride())
        try:
            major_dim = strides.index(1)
        except ValueError:
            # No unit-stride dimension: every stride must pass the check.
            major_dim = -1
        # uint8 is treated as 4-bit elements (presumably packed fp4 -- TODO confirm).
        bitwidth = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
        return all(
            stride * bitwidth % 128 == 0
            for dim, stride in enumerate(strides)
            if dim != major_dim
        )

    def make_dense_tma(self, block_shape, transpose=False):
        """Build a dense ``TensorDescriptor`` for ``data``.

        Args:
            block_shape: Per-dimension block sizes (any sequence of ints). It
                is copied first, so the caller's object is never modified.
            transpose: Ignored; kept for backward compatibility. Whether the
                last two dimensions are swapped is decided from the strides:
                they are swapped whenever the last stride is not 1.

        Returns:
            A ``TensorDescriptor`` over ``data`` with the (possibly swapped)
            shape and strides and the layout-swizzled block shape.

        Raises:
            ValueError: for uint8 data in the "BLACKWELL_VALUE" layout when
                the innermost (post-swap) dimension is not a multiple of 128.
        """
        # Private list copy: the halving below writes into block_shape by
        # index and must not mutate the caller's list; this also accepts tuples.
        block_shape = list(block_shape)
        strides = list(self.data.stride())
        shape = list(self.data.shape)
        # The caller's `transpose` is deliberately overridden from the strides.
        transpose = self.data.stride()[-1] != 1
        if transpose:
            block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
            shape = shape[:-2] + [shape[-1], shape[-2]]
            strides = strides[:-2] + [strides[-1], strides[-2]]
        if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE":
            # Halve the block along the unit-stride dimension (presumably two
            # 4-bit values per byte -- TODO confirm).
            indx = strides.index(1)
            block_shape[indx] = block_shape[indx] // 2
            if shape[-1] % 128 != 0:
                raise ValueError("inner shape need to be multiple of 128 for "
                                 "mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs.")
        block_shape = self.layout.swizzle_block_shape(block_shape)
        return TensorDescriptor(self.data, shape, strides, block_shape)

    def make_tma(self, block_shape, mode, transpose=False):
        """Build a TMA descriptor for ``data`` in the given access ``mode``.

        "dense", "gather" and "scatter" delegate to ``make_dense_tma``.
        "ragged" calls ``create_ragged_descriptor`` with the second-to-last
        dimension as the ragged one (``transpose`` is unused there). Any
        other mode fails an assertion.
        """
        if mode in ["dense", "gather", "scatter"]:
            return self.make_dense_tma(block_shape, transpose)
        assert mode == "ragged"
        ragged_dim = len(self.data.shape) - 2
        return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim)
73
74
75@dataclass

Callers 4

_canonicalize_storage (Function) · 0.85
__post_init__ (Method) · 0.85
wrap_torch_tensor (Function) · 0.85
convert_layout (Function) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected