| 55 | |
@dataclass(frozen=True, eq=False)
class DType(metaclass=DTypeMetaClass):
  """Frozen description of a data type.

  A scalar dtype has ``count == 1`` and ``_scalar is None``; a dtype produced by
  :meth:`vec` has ``count == sz`` and keeps the originating scalar in ``_scalar``.
  Instances are always built with positional arguments in field order.
  """
  priority: int  # this determines when things get upcasted
  bitsize: int  # total size in bits (vec() multiplies the scalar bitsize by the lane count)
  name: str
  fmt: FmtStr|None  # None for dtypes created by vec()
  count: int  # 1 for scalars, the vector size for dtypes created by vec()
  _scalar: DType|None  # the scalar this vector was built from, None for scalars

  @property
  def itemsize(self) -> int:
    """Size in whole bytes, i.e. ``bitsize`` divided by 8 and rounded up."""
    return -(-self.bitsize // 8)

  @staticmethod
  def new(priority:int, bitsize:int, name:str, fmt:FmtStr|None):
    """Create a scalar dtype (``count`` of 1, no ``_scalar``)."""
    return DType(priority, bitsize, name, fmt, 1, None)

  def __reduce__(self):
    """Pickle support: rebuild from ``type(self)`` and every dataclass field value, in field order."""
    field_values = tuple(getattr(self, fld.name) for fld in fields(self))
    return type(self), field_values

  def __repr__(self):
    """Render as ``dtypes.<name>`` plus a ``.vec(n)`` suffix when ``count`` is not 1."""
    scalar_name = INVERSE_DTYPES_DICT[self.scalar().name]
    vec_suffix = "" if self.count == 1 else f".vec({self.count})"
    return f"dtypes.{scalar_name}" + vec_suffix

  def __lt__(self, o:DType):
    """Order lexicographically by (priority, bitsize, name, fmt, count)."""
    mine = (self.priority, self.bitsize, self.name, self.fmt, self.count)
    theirs = (o.priority, o.bitsize, o.name, o.fmt, o.count)
    return mine < theirs

  @property
  def base(self):
    """A plain DType is its own base."""
    return self

  @property
  def vcount(self):
    """Vector count; identical to ``count`` for a plain DType."""
    return self.count

  @functools.cache  # pylint: disable=method-cache-max-size-none
  def vec(self, sz:int) -> DType:
    """Return the ``sz``-wide vector form of this scalar dtype.

    Results are memoized per ``(self, sz)``. Asserts that ``self`` is a scalar.
    """
    assert self.count == 1, f"can't vectorize {self} with size {sz}"
    # sz=1 is scalar
    if sz == 1:
      return self
    # void doesn't vectorize
    if self == dtypes.void:
      return self
    vec_name = f"{INVERSE_DTYPES_DICT[self.name]}{sz}"
    return DType(self.priority, self.bitsize*sz, vec_name, None, sz, self)

  def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
    """Build a PtrDType from this dtype's fields, with ``self`` as the pointed-to type.

    NOTE(review): the trailing positional arguments (self, addrspace, 1, size) must line up
    with PtrDType's own fields, which are declared outside this block -- confirm there.
    """
    return PtrDType(self.priority, self.bitsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)

  def scalar(self) -> DType:
    """Return the scalar dtype behind a vector, or ``self`` when there is none."""
    inner = self._scalar
    return self if inner is None else inner

  def nbytes(self) -> int:
    """Always raises RuntimeError on a plain DType."""
    raise RuntimeError("only ptr types have nbytes")

  @functools.cached_property
  def min(self):
    """Smallest value: 0 / -2**(bits-1) for unsigned / signed ints, -inf for floats, else False."""
    if dtypes.is_int(self):
      if dtypes.is_unsigned(self):
        return 0
      return -(2 ** (self.scalar().bitsize - 1))
    if dtypes.is_float(self):
      return -math.inf
    return False

  @functools.cached_property
  def max(self):
    """Largest value: ``2**bits - 1 + min`` for ints, +inf for floats, else True."""
    if dtypes.is_int(self):
      return 2 ** (self.scalar().bitsize) - 1 + self.min
    if dtypes.is_float(self):
      return math.inf
    return True

  def const(self, val: tuple[ConstType, ...]|ConstType):
    """Coerce ``val`` to this dtype's Python constant representation.

    A tuple must have exactly ``count`` entries and is converted element by element.
    An InvalidType value is returned untouched. Otherwise the result is a ConstFloat
    for float dtypes, a bool for bool dtypes, and an int for everything else.
    """
    if isinstance(val, tuple):
      assert len(val) == self.count, f"mismatch {val} {self}"
      return tuple(self.const(item) for item in val)
    if isinstance(val, InvalidType):
      return val
    # NOTE: float('nan') != float('nan'), so we canonicalize here
    if isinstance(val, float) and math.isnan(val):
      val = math.nan
    # int is the default. wrap floats in ConstFloat to distinguish -0.0 from 0.0 in cache
    if dtypes.is_float(self):
      return ConstFloat(float(val))
    if dtypes.is_bool(self):
      return bool(val)
    return int(val)
| 101 | |
| 102 | @dataclass(frozen=True, eq=False) |
| 103 | class PtrDType(DType): |