| 157 | |
def __getitem__(self, key):
  """Return the rendered name/expression recorded for `key` in `self.r`.

  `self.r` is the UOp -> string map that `_render` installs for the duration of a render
  (it is passed as `ctx=self` to `string_rewrite`, presumably so rules can write `ctx[u]`).
  Raises KeyError for an unknown key and AttributeError when no render is in progress.
  """
  rendered_names = self.r
  return rendered_names[key]
def _render(self, uops:list[UOp]) -> tuple[str, list[str], list[tuple[str,tuple[DType,bool]]]]:
  """Render a linearized UOp list into the pieces of a C-style kernel.

  Every UOp gets an entry in the name map `r`, which is exposed as `self.r` while rendering so that
  `self[u]` works from the `string_rewrite` rules (they receive `ctx=self`). A UOp's rendered text is
  either inlined into its consumers (the expression itself is stored in `r`) or emitted as its own line
  in `kernel`.

  Returns:
    A tuple `(name, kernel, bufs)`:
      - name: the SINK's `arg.function_name` when the SINK has an arg, otherwise "test".
      - kernel: the rendered body lines, indented by the current nesting depth.
      - bufs: one `(name, (dtype, writable))` entry per PARAM / DEFINE_VAR in first-seen order;
        `writable` is True only for PARAMs reachable from a STORE's first source.

  Raises:
    AssertionError: if `string_rewrite` fails to render a UOp.
  """
  r: dict[UOp, str] = {}
  self.r = r
  try:
    # number of times each UOp is used as a source; single-use ALU/cast results may be inlined
    child_count = Counter(v for ru in uops for v in ru.src)
    # find which PARAMs are stored to with a single toposort
    # NOTE(review): the lambda is passed to toposort as a gate on END; presumably it stops the walk at END nodes — confirm
    writable_params = {u for u in UOp.sink(*[u.src[0] for u in uops if u.op is Ops.STORE]).toposort(lambda u: u.op != Ops.END) if u.op is Ops.PARAM}
    bufs: dict[UOp, tuple[str, tuple[DType, bool]]] = {}
    kernel = []
    depth = 1
    c: defaultdict[str, int] = defaultdict(int)
    name = "test"

    # loop-invariant lookup tables, built once instead of once per UOp
    prefixes = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
                Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
                Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}
    single_use_inline_ops = {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST}

    for u in uops:
      if u.op in {Ops.NOOP, Ops.GROUP}: continue
      if u.op is Ops.AFTER:
        # AFTER renders as whatever its first source rendered as
        r[u] = r[u.src[0]]
        continue
      if u.op is Ops.SINK:
        if u.arg is not None: name = u.arg.function_name
        continue
      if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
        # kernel arguments: named and collected into bufs, never emitted as body lines
        if u.op is not Ops.PARAM: r[u] = u.arg[0]
        elif isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
        else: r[u] = f"data{u.arg}_{sz}" if (sz:=u.ptrdtype.size) > 0 else f"data{u.arg}"
        bufs[u] = (r[u], (u.dtype, u in writable_params))
        continue

      # naming
      prefix = None
      if u.op is Ops.SPECIAL: r[u] = u.arg
      elif u.op is Ops.RANGE: r[u] = f"{axis_letters[u.arg[-1]]}idx"+range_str(u)
      else:
        prefix = prefixes.get(u.op, "alu")
        r[u] = f"{prefix}{c[prefix]}"

      l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
      assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

      # closing a scope: dedent before emitting the closing line
      if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
      if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or
        (u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or
        (u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or
        (u.op in single_use_inline_ops and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
        # inline: consumers see the expression itself instead of a variable name
        r[u] = l
      else:
        if u.op not in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} and u.dtype != dtypes.void:
          # value-producing op: emit a typed declaration of the variable r[u]
          l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
        kernel.append(" "*depth + l)
        if prefix: c[prefix] += 1 # if it was used, increment
      # opening a scope: indent the lines that follow
      if u.op in {Ops.IF, Ops.RANGE}: depth += 1
  finally:
    # always drop the temporary name map, even when rendering a UOp failed
    del self.r

  # NOTE: this relies on bufs dict preserving order
  return (name, kernel, list(bufs.values()))
def render(self, uops:list[UOp]) -> str:
  """Render `uops` to a complete kernel source string.

  `_render` produces the kernel pieces (name, body lines, buffer list); they are forwarded
  positionally to `render_kernel`, followed by the original `uops`.
  """
  pieces = self._render(uops)
  return self.render_kernel(*pieces, uops)
| 216 | |