MCPcopy
hub / github.com/tinygrad/tinygrad / DType

Class DType

tinygrad/dtype.py:57–100  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

55
@dataclass(frozen=True, eq=False)
class DType(metaclass=DTypeMetaClass):
  """A scalar or vector element type.

  frozen=True: fields cannot be reassigned after construction.
  eq=False: the dataclass generates no __eq__/__hash__, so equality and hashing are inherited
  (object identity unless overridden elsewhere).
  Every construction site below passes fields positionally. NOTE(review): DTypeMetaClass presumably
  caches/interns instances by their positional constructor args -- confirm before switching to keywords.
  """
  priority: int  # this determines when things get upcasted
  bitsize: int  # total width in bits; vec() stores element bitsize * lane count here
  name: str
  fmt: FmtStr|None  # None for vectors built by vec(); NOTE(review): looks like a struct-style format code -- confirm
  count: int  # number of lanes; 1 for scalars
  _scalar: DType|None  # element dtype for vectors built by vec(); None for scalars

  @property
  def itemsize(self) -> int:
    """Size in whole bytes, rounding a partial byte up."""
    return (self.bitsize + 7) // 8

  @staticmethod
  def new(priority:int, bitsize:int, name:str, fmt:FmtStr|None):
    """Create a scalar dtype (count=1, no element dtype)."""
    return DType(priority, bitsize, name, fmt, 1, None)

  def __reduce__(self):
    # pickle by re-calling the concrete class with every dataclass field, in declaration order
    ctor_args = tuple(getattr(self, fld.name) for fld in fields(self))
    return type(self), ctor_args

  def __repr__(self):
    rendered = f"dtypes.{INVERSE_DTYPES_DICT[self.scalar().name]}"
    if self.count != 1: rendered += f".vec({self.count})"
    return rendered

  def __lt__(self, o:DType):
    # lexicographic order: priority, then bitsize, name, fmt, and finally lane count
    def sort_key(d): return (d.priority, d.bitsize, d.name, d.fmt, d.count)
    return sort_key(self) < sort_key(o)

  @property
  def base(self):
    """This implementation returns the dtype itself."""
    return self

  @property
  def vcount(self):
    """Lane count; this implementation returns count."""
    return self.count

  @functools.cache  # pylint: disable=method-cache-max-size-none
  def vec(self, sz:int) -> DType:
    """Return the sz-lane vector form of this scalar dtype (memoized per (self, sz) by functools.cache).

    Returns self unchanged when sz == 1 or self equals dtypes.void. Asserts that self is not already a vector.
    """
    assert self.count == 1, f"can't vectorize {self} with size {sz}"
    if sz == 1 or self == dtypes.void: return self  # void doesn't vectorize, and sz=1 is scalar
    vec_name = f"{INVERSE_DTYPES_DICT[self.name]}{sz}"
    return DType(self.priority, self.bitsize*sz, vec_name, None, sz, self)

  def ptr(self, size=-1, addrspace=AddrSpace.GLOBAL) -> PtrDType:
    """Return a PtrDType whose base is this dtype.

    NOTE(review): size=-1 presumably means "unknown size" -- confirm against PtrDType.
    """
    return PtrDType(self.priority, self.bitsize, self.name, self.fmt, self.count, None, self, addrspace, 1, size)

  def scalar(self) -> DType:
    """Return the element dtype of a vector, or self when there is none."""
    return self if self._scalar is None else self._scalar

  def nbytes(self) -> int:
    """Not defined for plain dtypes; always raises RuntimeError."""
    raise RuntimeError("only ptr types have nbytes")

  @functools.cached_property
  def min(self):
    """Smallest value: integer range bound, -inf for floats, False for anything else."""
    if not dtypes.is_int(self): return -float("inf") if dtypes.is_float(self) else False
    if dtypes.is_unsigned(self): return 0
    return -2**(self.scalar().bitsize-1)  # signed lower bound uses the element bitsize

  @functools.cached_property
  def max(self):
    """Largest value: integer range bound, +inf for floats, True for anything else."""
    if not dtypes.is_int(self): return float("inf") if dtypes.is_float(self) else True
    # the range spans 2**bits - 1 above min, which covers both signed and unsigned ints
    return 2**(self.scalar().bitsize)-1+self.min

  def const(self, val: tuple[ConstType, ...]|ConstType):
    """Coerce val to the Python constant representation for this dtype.

    Tuples are converted element-wise and must have exactly `count` entries. InvalidType values pass
    through unchanged. Float dtypes yield ConstFloat, bool dtypes yield bool, and everything else yields int.
    """
    if isinstance(val, tuple):
      assert len(val) == self.count, f"mismatch {val} {self}"
      return tuple(self.const(elem) for elem in val)
    if isinstance(val, InvalidType): return val
    # NOTE: float('nan') != float('nan'), so we canonicalize here
    if isinstance(val, float) and math.isnan(val): val = math.nan
    # int is the default. wrap floats in ConstFloat to distinguish -0.0 from 0.0 in cache
    if dtypes.is_float(self): return ConstFloat(float(val))
    if dtypes.is_bool(self): return bool(val)
    return int(val)
101
102@dataclass(frozen=True, eq=False)
103class PtrDType(DType):

Callers 2

new (method) · 0.85
vec (method) · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…