MCPcopy
hub / github.com/tinygrad/tinygrad / from_uops

Method from_uops

tinygrad/renderer/__init__.py:21–60  ·  view source on GitHub ↗
(uops:tuple[UOp, ...], ignore_indexing=False)

Source from the content-addressed store, hash-verified

def simplify(self):
  """Return a new Estimates whose ops, lds and mem have each been passed through ssimplify."""
  ops = ssimplify(self.ops)
  lds = ssimplify(self.lds)
  mem = ssimplify(self.mem)
  return Estimates(ops, lds, mem)
@staticmethod
def from_uops(uops:tuple[UOp, ...], ignore_indexing=False) -> Estimates:
  """
  Build an Estimates(flops, lds, mem) by walking `uops` in order.

  A running multiplier is maintained while walking: it is scaled by the simplified
  first source of every RANGE and SPECIAL, by `arg[2] + 1` of a DEFINE_VAR whose
  `arg[0]` is 'core_id', and restored to its pre-RANGE value on every END.

  * flops: each GroupOp.ALU uop adds `multiplier * dtype.count` (doubled for MULACC);
    each WMMA adds `2 * prod(arg[1]) // arg[5] * multiplier`.
  * lds: each LOAD / STORE whose address dtype is not a PtrDType in AddrSpace.REG adds
    the item size of the loaded / stored value times the multiplier.
  * mem: for each LOAD / STORE whose src[0] chain ends in a PARAM, bytes are accumulated
    per (PARAM, op) pair and capped at the PARAM's `ptrdtype.nbytes()` unless its
    `ptrdtype.size` is -1. The returned mem is the sum over all pairs.

  Args:
    uops: the uops to analyse; iterated in order (twice when `ignore_indexing` is set).
    ignore_indexing: when true, uops reachable from LOAD/STORE address expressions, from a
      LOAD/STORE third source, and from the first source of an IF are not counted as flops.

  Returns:
    Estimates constructed positionally from (flops, lds, summed mem bytes).
  """
  mem_ops = {Ops.LOAD, Ops.STORE}
  total_flops: sint = 0
  total_lds: sint = 0
  bytes_by_buf: dict[tuple[UOp, Ops], sint] = {}
  scale: sint = 1
  saved_scales: list[sint] = []
  excluded: set[UOp] = set()

  def outside_registers(addr_dtype):
    # true unless the address is a pointer whose address space is AddrSpace.REG
    return not isinstance(addr_dtype, PtrDType) or addr_dtype.addrspace != AddrSpace.REG

  # pre-pass: gather the uops that must not contribute to the flop count
  if ignore_indexing:
    # handed to toposort as its gate; presumably keeps the walk from going through RANGE nodes -- confirm
    def stop_at_range(x): return x.op is not Ops.RANGE
    for u in uops:
      if u.op in mem_ops:
        addr = u.src[0]
        # for an INDEX address only its sources after the first (src[1:]) are walked, otherwise the whole address is
        # NOTE(review): the previous comment said the buffer has to be *included* because it might be an AFTER,
        # which does not obviously match src[1:] -- confirm the intent
        root = UOp.sink(*addr.src[1:]) if addr.op is Ops.INDEX else addr
        excluded.update(root.toposort(stop_at_range))
        # TODO: is this correct? this all needs to be cleaned up
        if len(u.src) > 2:
          excluded.update(u.src[2].toposort())
      elif u.op is Ops.IF:
        excluded.update(u.src[0].toposort())

  # main pass: accumulate mem, then update the multiplier or the lds / flops counters
  for u in uops:
    if u.op in mem_ops:
      # follow src[0] links until a uop with no sources is reached
      root_buf = u
      while len(root_buf.src) > 0:
        root_buf = root_buf.src[0]
      if root_buf.op is Ops.PARAM:
        key = (root_buf, u.op)
        # u.src[0] is INDEX; re-reads (e.g. matmul) are capped at the buffer size when that size is known
        running = bytes_by_buf.get(key, 0) + u.src[0].dtype.base.itemsize * scale
        if root_buf.ptrdtype.size != -1:
          running = smin(running, root_buf.ptrdtype.nbytes())
        bytes_by_buf[key] = running

    if u.op is Ops.RANGE:
      saved_scales.append(scale)
      scale *= cast(sint, u.src[0].ssimplify())
      # SPECIAL are already counted in the multiplier: replace each one in the product with x.const_like(0)
      if isinstance(scale, UOp):
        scale = scale.substitute({x:x.const_like(0) for x in scale.toposort() if x.op is Ops.SPECIAL})
    elif u.op is Ops.END:
      scale = saved_scales.pop()
    elif u.op is Ops.SPECIAL:
      # NOTE: nothing is pushed on the stack here, a SPECIAL can't be ended
      scale *= cast(sint, u.src[0].ssimplify())
    elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id':
      scale *= u.arg[2] + 1
    elif u.op is Ops.LOAD and outside_registers(u.src[0].dtype):
      total_lds += u.dtype.itemsize * scale
    elif u.op is Ops.STORE and outside_registers(u.src[0].dtype):
      total_lds += u.src[1].dtype.itemsize * scale
    elif u.op in GroupOp.ALU and u not in excluded:
      total_flops += (scale * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
    elif u.op is Ops.WMMA and u not in excluded:
      total_flops += 2 * prod(u.arg[1]) // u.arg[5] * scale

  return Estimates(total_flops, total_lds, sum(bytes_by_buf.values()))
61
62class Renderer:
63 target: Target

Callers 2

do_estimates · Function · 0.80
flops_mem · Function · 0.80

Calls 12

smin · Function · 0.90
prod · Function · 0.90
cast · Function · 0.85
Estimates · Class · 0.85
toposort · Method · 0.80
append · Method · 0.80
ssimplify · Method · 0.80
substitute · Method · 0.80
sink · Method · 0.45
get · Method · 0.45
nbytes · Method · 0.45
const_like · Method · 0.45

Tested by 1

flops_mem · Function · 0.64