| 16 | |
@dataclass
class Storage:
    """Pair a raw ``torch.Tensor`` with the ``Layout`` that describes it.

    Attributes:
        data: the underlying tensor; ``__post_init__`` asserts it is a
            ``torch.Tensor``.
        layout: layout object for ``data``. When left as ``None`` it is
            replaced by ``StridedLayout(data.shape)``.
    """

    data: torch.Tensor
    layout: Layout = None

    def __post_init__(self):
        assert isinstance(self.data, torch.Tensor)
        if self.layout is None:
            self.layout = StridedLayout(self.data.shape)

    @property
    def device(self):
        """Device of the underlying tensor."""
        return self.data.device

    def is_tma_compliant(self):
        """Return True if ``data`` passes the TMA compatibility checks below.

        The checks are:
          * ``cuda_capability_geq(9, 0)`` holds (TMAs didn't exist until Hopper);
          * the tensor rank is 2, 3 or 5;
          * every stride except the first unit stride (if any) spans a
            multiple of 128 bits, i.e. 16 bytes.
        """
        # TMAs didn't exist until Hopper.
        if not cuda_capability_geq(9, 0):
            return False
        # TMAs only exist for 2D, 3D and 5D inputs.
        if self.data.ndim not in (2, 3, 5):
            return False
        strides = list(self.data.stride())
        # The first dimension with unit stride is exempt from the alignment
        # check. -1 (no unit stride) never matches a real dimension index, so
        # every stride is then checked.
        try:
            major_dim = strides.index(1)
        except ValueError:
            major_dim = -1
        # uint8 storage is counted as 4 bits per element. Presumably the
        # elements are packed 4-bit (mxfp4) values -- see the mxfp4 error in
        # make_dense_tma.
        bitwidth = 4 if self.data.dtype == torch.uint8 else self.data.element_size() * 8
        return all(
            stride * bitwidth % 128 == 0
            for dim, stride in enumerate(strides)
            if dim != major_dim
        )

    def make_dense_tma(self, block_shape, transpose=False):
        """Build a ``TensorDescriptor`` covering all of ``data``.

        Args:
            block_shape: per-dimension block sizes (any sequence). It is copied
                first, so the caller's object is never modified.
            transpose: ignored. Whether to swap the last two dimensions is
                always inferred from ``data``: they are swapped iff the last
                stride is not 1. The parameter is kept only for API
                compatibility.

        Returns:
            ``TensorDescriptor(self.data, shape, strides, block_shape)``. Here
            ``shape`` and ``strides`` are the (possibly swapped) tensor
            metadata, and ``block_shape`` is the result of
            ``self.layout.swizzle_block_shape``.

        Raises:
            ValueError: for uint8 data with a ``"BLACKWELL_VALUE"`` layout, if
                the innermost (post-swap) extent is not a multiple of 128, or
                if no stride equals 1 (raised by ``list.index``).
        """
        # Work on a private copy. The uint8 path below assigns into
        # block_shape, and without this copy it would halve the caller's list
        # in place whenever no swap occurred. This also makes tuple inputs work.
        block_shape = list(block_shape)
        strides = list(self.data.stride())
        shape = list(self.data.shape)
        # `transpose` is intentionally overridden by the stride layout.
        if strides[-1] != 1:
            block_shape = block_shape[:-2] + [block_shape[-1], block_shape[-2]]
            shape = shape[:-2] + [shape[-1], shape[-2]]
            strides = strides[:-2] + [strides[-1], strides[-2]]
        if self.data.dtype == torch.uint8 and self.layout.name == "BLACKWELL_VALUE":
            # Halve the block extent along the unit-stride dimension; uint8 is
            # presumably holding two 4-bit values per byte here.
            indx = strides.index(1)
            block_shape[indx] = block_shape[indx] // 2
            if shape[-1] % 128 != 0:
                raise ValueError("inner shape need to be multiple of 128 for "
                                 "mxfp4 (CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B) TMAs.")
        block_shape = self.layout.swizzle_block_shape(block_shape)
        return TensorDescriptor(self.data, shape, strides, block_shape)

    def make_tma(self, block_shape, mode, transpose=False):
        """Build a TMA descriptor for ``data`` in the given access ``mode``.

        ``"dense"``, ``"gather"`` and ``"scatter"`` all delegate to
        ``make_dense_tma``, which ignores ``transpose``. ``"ragged"`` delegates
        to ``create_ragged_descriptor``, using the second-to-last dimension as
        the ragged one. Any other mode fails the assertion.
        """
        if mode in ("dense", "gather", "scatter"):
            return self.make_dense_tma(block_shape, transpose)
        assert mode == "ragged"
        ragged_dim = self.data.ndim - 2
        return create_ragged_descriptor(self.data, block_shape, ragged_dim=ragged_dim)
| 73 | |
| 74 | |
| 75 | @dataclass |
no outgoing calls
no test coverage detected