| 9 | |
@dataclass(frozen=True)
class Estimates:
  """Static cost estimates for a kernel: op count, bytes moved by loads/stores, and bytes touched.

  Instances are immutable; `a + b` sums the three fields element-wise.
  """
  # number of FLOPS used in the Kernel
  ops:sint = 0
  # bytes accessed in loads and stores
  lds:sint = 0
  # total bytes accessed, counting only once for bytes that are accessed multiple times
  mem:sint = 0
  def __add__(self, o:Estimates): return Estimates(self.ops + o.ops, self.lds + o.lds, self.mem + o.mem)
  # run ssimplify over every field (each field may be a symbolic expression)
  def simplify(self): return Estimates(ssimplify(self.ops), ssimplify(self.lds), ssimplify(self.mem))
  @staticmethod
  def from_uops(uops:tuple[UOp, ...], ignore_indexing=False) -> Estimates:
    """Estimate ops/lds/mem by walking `uops` once, in order.

    Each counted UOp is weighted by `mults`: the product of the extents of the
    currently open RANGEs, times every SPECIAL size seen so far, times
    (arg[2] + 1) of a `core_id` DEFINE_VAR. A RANGE saves the previous `mults`
    on a stack and its matching END restores it, so `uops` is assumed to be
    linearized with every END following its RANGE (an unmatched END would fail
    to pop the stack).

    Args:
      uops: the kernel's UOps in linearized order.
      ignore_indexing: if True, ALU/WMMA UOps reachable from a LOAD/STORE's
        address operands, from a LOAD/STORE's optional third source, or from
        an IF condition are excluded from `ops`.

    Returns:
      Estimates(ops, lds, mem), where `mem` sums, per (PARAM buffer, LOAD or
      STORE) pair, the bytes accessed, capped at the buffer's size when known.
    """
    flops: sint = 0
    lds: sint = 0
    # bytes accessed per (root PARAM buffer, LOAD or STORE); loads and stores are tracked separately
    mem: dict[tuple[UOp, Ops], sint] = {}
    # weight of the current loop nest, plus the values to restore at each END
    mults: sint = 1
    mult_stack: list[sint] = []
    # UOps left out of the op count when ignore_indexing is set
    dont_count: set[UOp] = set()
    if ignore_indexing:
      # used as toposort's gate so the walk does not continue past RANGE nodes
      def range_gate(x): return x.op is not Ops.RANGE
      for u in uops:
        if u.op in {Ops.LOAD, Ops.STORE}:
          # for an INDEX walk only the index operands (src[1:]) and skip the buffer (src[0]): it might be
          # an AFTER, and walking it would presumably pull real compute into dont_count.
          # set.update grows the set in place (and accepts toposort's dict); rebuilding it with
          # union() on every access made this pass quadratic in the size of dont_count.
          dont_count.update((UOp.sink(*u.src[0].src[1:]) if u.src[0].op is Ops.INDEX else u.src[0]).toposort(range_gate))
          # TODO: is this correct? this all needs to be cleaned up
          if len(u.src) > 2: dont_count.update(u.src[2].toposort())
        elif u.op is Ops.IF:
          dont_count.update(u.src[0].toposort())
    for u in uops:
      if u.op in {Ops.LOAD, Ops.STORE}:
        # follow src[0] links down to the root; only roots that are PARAMs are tracked in mem
        buf = u
        while len(buf.src): buf = buf.src[0]
        if buf.op is Ops.PARAM:
          # u.src[0] is INDEX, cap at buffer size for re-reads (e.g. matmul)
          accessed = mem.get((buf, u.op), 0) + u.src[0].dtype.base.itemsize * mults
          mem[(buf, u.op)] = smin(accessed, buf.ptrdtype.nbytes()) if buf.ptrdtype.size != -1 else accessed
      if u.op is Ops.RANGE:
        mult_stack.append(mults)
        mults *= cast(sint, u.src[0].ssimplify())
        # SPECIAL are already counted in mults, so any SPECIAL inside the range extent is replaced by 0
        mults = mults.substitute({x:x.const_like(0) for x in mults.toposort() if x.op is Ops.SPECIAL}) if isinstance(mults, UOp) else mults
      elif u.op is Ops.END: mults = mult_stack.pop(-1)
      elif u.op is Ops.SPECIAL: mults *= cast(sint, u.src[0].ssimplify()) # NOTE: we don't push to the mult_stack here, you can't end these
      elif u.op is Ops.DEFINE_VAR and u.arg[0] == 'core_id': mults *= u.arg[2] + 1
      # LOAD/STOREs through a PtrDType in AddrSpace.REG are not counted in lds
      elif u.op is Ops.LOAD and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
        lds += u.dtype.itemsize * mults
      elif u.op is Ops.STORE and (not isinstance(u.src[0].dtype, PtrDType) or u.src[0].dtype.addrspace != AddrSpace.REG):
        lds += u.src[1].dtype.itemsize * mults
      # MULACC counts as two ops; ALU results are further scaled by dtype.count (presumably the vector width)
      elif u.op in GroupOp.ALU and u not in dont_count: flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count
      elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults
    return Estimates(flops, lds, sum(mem.values()))
| 61 | |
| 62 | class Renderer: |
| 63 | target: Target |
no outgoing calls
searching dependent graphs…