| 99 | |
@dataclass
class Tensor:
    """A `Storage` paired with a logical dtype, shape, and per-dimension bound.

    Fields:
        storage: Either a `Storage` or a raw `torch.Tensor`; a raw tensor is
            wrapped in `Storage` during `__post_init__`.
        dtype: Logical element type. When omitted, it is taken from
            `storage.data.dtype`.
        shape: Logical shape. Each entry must be an `int` or an object whose
            `numel()` is 1. When omitted, it is taken from
            `storage.data.shape`. It must be given explicitly when
            `bitwidth(dtype) < 8`.
        shape_max: Per-dimension `int` bound, same length as `shape`. Entries
            left as `None` (or the whole list, when omitted) are filled with
            the matching `shape` entry. Because every final entry must be an
            `int`, a non-`int` `shape` entry needs an explicit `int` here.
    """

    storage: Storage | torch.Tensor
    dtype: IntegerType | FloatType | torch.dtype | None = None
    shape: list[int] | None = None
    shape_max: list[int] | None = None

    def __post_init__(self):
        """Normalize `storage`, fill in defaulted fields, and validate shapes.

        Raises:
            ValueError: if `shape` is omitted for a sub-byte dtype, if
                `shape_max` and `shape` differ in length, or if a
                `shape_max` entry is neither `int` nor `None`.
            AssertionError: if a `shape` entry is neither an `int` nor a
                single-element object, or if a resolved `shape_max` entry is
                not an `int`.
        """

        def is_int(value):
            return isinstance(value, int)

        def is_item(value):
            # duck-typed check for a single-element tensor-like object
            return hasattr(value, "numel") and value.numel() == 1

        # set storage
        if isinstance(self.storage, torch.Tensor):
            self.storage = Storage(self.storage)
        # initialize dtype
        if self.dtype is None:
            self.dtype = self.storage.data.dtype
        if bitwidth(self.dtype) < 8 and self.shape is None:
            raise ValueError("shape must be provided for sub-byte types")
        # initialize shape
        if self.shape is None:
            self.shape = list(self.storage.data.shape)
        # validate shape: all elements must be `int` or numel-1 `torch.Tensor`
        assert all(is_int(s) or is_item(s) for s in self.shape), \
            "every shape entry must be an `int` or a single-element tensor"
        # initialize shape_max on a private copy so that the list object passed
        # in by the caller is never modified
        if self.shape_max is None:
            shape_max = [None] * len(self.shape)
        else:
            shape_max = list(self.shape_max)
        # `zip` below would silently stop at the shorter sequence; reject a
        # rank mismatch explicitly instead of leaving it undetected
        if len(shape_max) != len(self.shape):
            raise ValueError(
                f"shape_max must have the same length as shape; "
                f"got {len(shape_max)} and {len(self.shape)}"
            )
        for i, (s, smax) in enumerate(zip(self.shape, shape_max)):
            if smax is None:
                # unspecified bound defaults to the current extent
                shape_max[i] = s
            elif not is_int(smax):
                raise ValueError(f"shape_max[{i}] must be `int` or `None`; got {type(smax)}")
        self.shape_max = shape_max
        # validate shape_max: all elements must be `int`
        assert all(map(is_int, self.shape_max)), \
            "every shape_max entry must resolve to an `int`"

    # torch compatibility layer
    @property
    def ndim(self):
        """Number of dimensions, i.e. `len(self.shape)`."""
        return len(self.shape)

    @property
    def device(self):
        """Device reported by the underlying storage."""
        return self.storage.device

    def stride(self, i=None):
        """Return the strides of the storage data, or stride `i` when given."""
        data = self.storage.data
        return data.stride() if i is None else data.stride(i)

    def data_ptr(self):
        """Return `data_ptr()` of the storage data."""
        return self.storage.data.data_ptr()

    def numel(self):
        """Return `numel()` of the storage data.

        NOTE(review): this is computed from the storage, not from `shape`;
        confirm whether the two are expected to agree for sub-byte dtypes.
        """
        return self.storage.data.numel()

    def element_size(self):
        """Return the element size in whole bytes: `bitwidth(dtype) // 8`.

        The division is a floor division, so a bit width below 8 yields 0.
        """
        return bitwidth(self.dtype) // 8

    @property
    def data(self):
        """Return `storage.data`, or `storage` itself if it is not a `Storage`."""
        t = self.storage
        return t.data if isinstance(t, Storage) else t
no outgoing calls
no test coverage detected