MCP · copy
hub / github.com/tinygrad/tinygrad / UOp

Class UOp

tinygrad/uop/ops.py:128–1035  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

126# NOTE: this should be frozen, but frozen is slower
127@dataclass(eq=False, slots=True)
128class UOp(OpMixin, metaclass=UOpMetaClass):
129 op:Ops
130 dtype:DType = dtypes.void
131 src:tuple[UOp, ...] = tuple()
132 arg:Any = None
133 tag:Any = None
def __del__(self):
  """Finalizer: release what this node registered at module level.

  For a BUFFER node that has an entry in ``buffers``, calls ``ref(-1)`` on that entry.
  Then removes this node's ``(op, dtype, src, arg, tag)`` key from ``UOpMetaClass.ucache``.
  """
  # NOTE(review): the `Ops is not None` check presumably covers interpreter shutdown,
  # when module globals may already have been cleared to None -- confirm
  if Ops is not None and self.op is Ops.BUFFER:
    buf = buffers.get(self)
    if buf is not None: buf.ref(-1)
  try:
    del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg, self.tag)]
  except AttributeError:
    # only AttributeError is swallowed here; a KeyError is not caught
    pass
def __reduce__(self):
  """Pickle support: rebuild this node by calling ``UOp`` with its stored fields.

  The constructor receives op, dtype, src, arg, tag and metadata positionally. A BUFFER
  node whose ``realized`` value is not None also passes that value as a trailing argument.
  """
  ctor_args = (self.op, self.dtype, self.src, self.arg, self.tag, self.metadata)
  if self.op is Ops.BUFFER and self.realized is not None: ctor_args += (self.realized,)
  return UOp, ctor_args
def replace(self, **kwargs) -> UOp:
  """Return a UOp like this one with the given fields overridden.

  Accepted keywords are ``op``, ``dtype``, ``src``, ``arg`` and ``tag``; any other keyword
  fails an ``assert``. If every resulting field equals the current one, ``self`` is
  returned and no new node is constructed.
  """
  old = (self.op, self.dtype, self.src, self.arg, self.tag)
  new = tuple(kwargs.pop(field, value) for field, value in zip(("op", "dtype", "src", "arg", "tag"), old))
  assert not kwargs, f"unused kwargs in replace {list(kwargs)}"
  return self if old == new else UOp(*new)
def rtag(self, tag=True):
  """Shorthand for ``self.replace(tag=tag)``; ``tag`` defaults to True."""
  retagged = self.replace(tag=tag)
  return retagged
@recursive_property
def key(self) -> bytes:
  """SHA-256 digest of this node's structure.

  The hashed bytes are ``str((op, dtype, arg))`` encoded, followed by each src's own
  ``key`` in src order. ``tag`` is not part of the digest.
  """
  hasher = hashlib.sha256(str((self.op, self.dtype, self.arg)).encode())
  # feeding the chunks one by one yields the same digest as hashing their concatenation
  for child in self.src: hasher.update(child.key)
  return hasher.digest()
def __repr__(self):
  """Render this node with ``tinygrad.uop.render.pretty_print``."""
  # NOTE(review): imported lazily, presumably to avoid a circular import -- confirm
  from tinygrad.uop.render import pretty_print as _pretty_print
  return _pretty_print(self)
def argstr(self):
  """Display string for ``self.arg``.

  REDUCE nodes show their arg as a parenthesised, comma-separated list of ``str`` values.
  ConstFloat args are shown as ``ConstFloat(<float repr>)``. Anything else uses ``repr``.
  """
  if self.op is Ops.REDUCE: return f'({", ".join(map(str, self.arg))})'
  if isinstance(self.arg, ConstFloat): return f"ConstFloat({float.__repr__(self.arg)})"
  return repr(self.arg)
def tagstr(self):
  """Return the ``, tag=...`` display suffix, or an empty string when ``tag`` is None."""
  if self.tag is None: return ""
  return f", tag={self.tag}"
159
def f(self, op, **kwargs):
  """Build a new UOp with ``op`` whose only src is this node.

  ``dtype`` defaults to this node's dtype. All other keyword arguments are forwarded to ``UOp``.
  """
  dtype = kwargs.pop("dtype", self.dtype)
  return UOp(op, dtype=dtype, src=(self,), **kwargs)
161
@functools.cached_property
def backward_slice(self:UOp) -> dict[UOp, None]:
  """All nodes returned by ``self.toposort()`` except this node itself.

  Keeps toposort's ordering. Cached per instance after the first access.
  """
  order = self.toposort()
  del order[self]
  return order
167
@property
def backward_slice_with_self(self:UOp) -> dict[UOp, None]:
  """A new dict holding this node first, followed by the keys of ``backward_slice``."""
  res: dict[UOp, None] = {self: None}
  res.update(self.backward_slice)
  return res
def op_in_backward_slice_with_self(self, *ops:Ops) -> bool:
  """True if this node's op, or the op of any node in ``backward_slice``, is in ``ops``."""
  if self.op in ops: return True
  # walk backward_slice directly instead of building backward_slice_with_self's extra dict
  for node in self.backward_slice:
    if node.op in ops: return True
  return False
173
174 def toposort(self, gate:Callable|None=None, enter_calls=True) -> dict[UOp, None]:
175 cache: dict[UOp, None] = {}
176 stack: list[tuple[UOp, bool]] = [(self, False)] # each stack entry is (node, visited_flag)
177 while stack:
178 node, visited = stack.pop()
179 if node in cache: continue
180 if not visited:
181 if gate is None or gate(node):
182 stack.append((node, True)) # push node back on stack to process after its srcs
183 for s in reversed(node.src if enter_calls or node.op not in {Ops.CALL, Ops.FUNCTION} else node.src[1:]):
184 stack.append((s, False)) # push srcs on the stack
185 else: cache[node] = None # second time i'm seeing this node, add it to returned toposort

Callers 15

_make_buffer_view (Function) · 0.90
function.py (File) · 0.90
call_gradient (Function) · 0.90
gradient.py (File) · 0.90
compute_gradient (Function) · 0.90
decode_hevc_frame (Method) · 0.90
MetalRenderer (Class) · 0.90
HIPRenderer (Class) · 0.90
packed_store (Function) · 0.90
packed_load (Function) · 0.90

Calls 1

count (Method) · 0.45

Tested by 15

_verify_indices_z3 (Method) · 0.72
test_global_prod_max (Method) · 0.72
test_where_cast (Method) · 0.72
invalid_fxn (Method) · 0.72
test_colored_label (Method) · 0.72
test_inf_loop (Method) · 0.72
test_gc_uop_in_arg (Method) · 0.72