| 126 | # NOTE: this should be frozen, but frozen is slower |
| 127 | @dataclass(eq=False, slots=True) |
| 128 | class UOp(OpMixin, metaclass=UOpMetaClass): |
| 129 | op:Ops |
| 130 | dtype:DType = dtypes.void |
| 131 | src:tuple[UOp, ...] = tuple() |
| 132 | arg:Any = None |
| 133 | tag:Any = None |
def __del__(self):
  """Finalizer: release this UOp's buffer reference and evict it from UOpMetaClass.ucache.

  For a BUFFER UOp that has an entry in `buffers`, calls `ref(-1)` on that entry. It then
  deletes the (op, dtype, src, arg, tag) entry from `UOpMetaClass.ucache`. An AttributeError
  raised while building that key or reaching the cache is swallowed. A missing key (KeyError)
  is not caught.
  """
  # NOTE(review): the `Ops is not None` test and the AttributeError guard presumably exist for
  # interpreter shutdown, when module globals may already have been cleared -- confirm.
  if Ops is not None and self.op is Ops.BUFFER:
    buf = buffers.get(self)
    if buf is not None: buf.ref(-1)
  try:
    # the key is built inside the try on purpose: reading unset slots raises AttributeError
    ucache = UOpMetaClass.ucache
    del ucache[(self.op, self.dtype, self.src, self.arg, self.tag)]
  except AttributeError:
    pass
def __reduce__(self):
  """Pickle hook: rebuild via UOp(op, dtype, src, arg, tag, metadata).

  For a BUFFER whose `realized` is not None, `realized` is passed as a seventh argument.
  """
  ctor_args = (self.op, self.dtype, self.src, self.arg, self.tag, self.metadata)
  if self.op is Ops.BUFFER and self.realized is not None:
    ctor_args = ctor_args + (self.realized,)
  return UOp, ctor_args
def replace(self, **kwargs) -> UOp:
  """Return a UOp equal to this one except for the overridden op/dtype/src/arg/tag fields.

  Any other keyword trips an assertion. If every resulting field equals the current one,
  `self` itself is returned instead of constructing a new UOp.
  """
  op = kwargs.pop("op", self.op)
  dtype = kwargs.pop("dtype", self.dtype)
  src = kwargs.pop("src", self.src)
  arg = kwargs.pop("arg", self.arg)
  tag = kwargs.pop("tag", self.tag)
  assert not kwargs, f"unused kwargs in replace {list(kwargs)}"
  new_args = (op, dtype, src, arg, tag)
  if (self.op, self.dtype, self.src, self.arg, self.tag) == new_args:
    return self
  return UOp(*new_args)
def rtag(self, tag=True):
  """Shorthand for `self.replace(tag=tag)`; the tag defaults to True."""
  return self.replace(tag=tag)
@recursive_property
def key(self) -> bytes:
  """SHA-256 digest of str((op, dtype, arg)) followed by the keys of all srcs, in src order.

  Note that `tag` does not contribute to the key.
  """
  own_part = str((self.op, self.dtype, self.arg)).encode()
  src_part = b"".join([s.key for s in self.src])
  return hashlib.sha256(own_part + src_part).digest()
def __repr__(self):
  """Render this UOp with tinygrad.uop.render.pretty_print, which is imported lazily here."""
  from tinygrad.uop.render import pretty_print
  text = pretty_print(self)
  return text
def argstr(self):
  """Format `self.arg` for display.

  - REDUCE: the args are str()-joined with ", " inside parentheses.
  - ConstFloat arg: shown as ConstFloat(<float repr>).
  - Otherwise: repr(arg).
  """
  if self.op is Ops.REDUCE:
    joined = ", ".join(map(str, self.arg))
    return f'({joined})'
  # NOTE(review): float.__repr__ being applied suggests ConstFloat subclasses float -- confirm
  if isinstance(self.arg, ConstFloat):
    return f"ConstFloat({float.__repr__(self.arg)})"
  return repr(self.arg)
def tagstr(self):
  """Return the ", tag=<tag>" suffix used in display, or "" when tag is None."""
  if self.tag is None:
    return ""
  return f", tag={self.tag}"
| 159 | |
def f(self, op, **kwargs):
  """Build UOp(op, src=(self,)) as a single-source UOp.

  The dtype defaults to self.dtype. Any remaining kwargs are forwarded to UOp unchanged.
  """
  dtype = kwargs.pop("dtype", self.dtype)
  return UOp(op, dtype=dtype, src=(self,), **kwargs)
| 161 | |
@functools.cached_property
def backward_slice(self:UOp) -> dict[UOp, None]:
  """All UOps returned by self.toposort() except self, in toposort order.

  Cached per instance via functools.cached_property.
  """
  deps: dict[UOp, None] = self.toposort()
  # remove self in place from the freshly built dict rather than copying it
  del deps[self]
  return deps
| 167 | |
@property
def backward_slice_with_self(self:UOp) -> dict[UOp, None]:
  """A new dict holding self first, followed by the entries of self.backward_slice."""
  result: dict[UOp, None] = {self: None}
  result.update(self.backward_slice)
  return result
def op_in_backward_slice_with_self(self, *ops:Ops) -> bool:
  """True if self.op, or the op of any UOp in self.backward_slice, is one of `ops`."""
  # check self first, then walk backward_slice directly (no intermediate dict with self)
  if self.op in ops:
    return True
  for u in self.backward_slice:
    if u.op in ops:
      return True
  return False
| 173 | |
| 174 | def toposort(self, gate:Callable|None=None, enter_calls=True) -> dict[UOp, None]: |
| 175 | cache: dict[UOp, None] = {} |
| 176 | stack: list[tuple[UOp, bool]] = [(self, False)] # each stack entry is (node, visited_flag) |
| 177 | while stack: |
| 178 | node, visited = stack.pop() |
| 179 | if node in cache: continue |
| 180 | if not visited: |
| 181 | if gate is None or gate(node): |
| 182 | stack.append((node, True)) # push node back on stack to process after its srcs |
| 183 | for s in reversed(node.src if enter_calls or node.op not in {Ops.CALL, Ops.FUNCTION} else node.src[1:]): |
| 184 | stack.append((s, False)) # push srcs on the stack |
| 185 | else: cache[node] = None # second time i'm seeing this node, add it to returned toposort |