MCPcopy
hub / github.com/tinygrad/tinygrad / _render

Method _render

tinygrad/renderer/cstyle.py:159–214  ·  view source on GitHub ↗
(self, uops:list[UOp])

Source from the content-addressed store, hash-verified

157
def __getitem__(self, key):
  """Return the string currently recorded for `key` in the render table `self.r`.

  Hacky helper: `self.r` is assigned at the start of `_render` and deleted at
  its end, so indexing the renderer is only expected to work while a render is
  in progress; a missing `self.r` surfaces as AttributeError and an unknown
  `key` as KeyError.
  """
  table = self.r
  return table[key]
def _render(self, uops:list[UOp]) -> tuple[str, list[str], list[tuple[str,tuple[DType,bool]]]]:
  """Render a linear list of UOps into a kernel name, body lines and buffer list.

  Walks `uops` once, in order. Every UOp that is not skipped gets a string in
  the table `r` (also published as `self.r`): either a generated variable name,
  when the UOp is emitted as its own statement, or the rendered expression
  itself, when it is folded inline into its users. `self` is passed as `ctx`
  to `self.string_rewrite.rewrite`, and `__getitem__` reads `self.r`.

  Args:
    uops: UOps to render, in emission order. Assumes each UOp's sources appear
      before it in the list (AFTER reads `r[u.src[0]]` directly) - TODO confirm
      this is guaranteed by the caller.

  Returns:
    A `(name, kernel, bufs)` tuple:
      * name: `function_name` of the arg of the last SINK whose arg is not
        None, or "test" when there is none.
      * kernel: rendered lines, each prefixed with one indent unit per current
        nesting depth (depth starts at 1).
      * bufs: `(rendered_name, (dtype, writable))` for every PARAM and
        DEFINE_VAR, in the order they were encountered.

  Raises:
    AssertionError: if `string_rewrite` produces no rendering for a UOp.
  """
  r: dict[UOp, str] = {}
  self.r = r
  # self.r only has meaning during this call; the finally below guarantees it is
  # removed even when rendering fails, so a failed render cannot leave a stale
  # table (and the UOps it references) attached to the renderer.
  try:
    # number of times each UOp is used as a source; consulted below to decide inlining
    child_count = Counter(v for ru in uops for v in ru.src)
    # find which PARAMs are stored to with a single toposort
    # NOTE(review): the gate lambda presumably stops the walk at END - confirm against UOp.toposort
    writable_params = {u for u in UOp.sink(*[u.src[0] for u in uops if u.op is Ops.STORE]).toposort(lambda u: u.op != Ops.END) if u.op is Ops.PARAM}
    bufs: dict[UOp, tuple[str, tuple[DType, bool]]] = {}
    kernel: list[str] = []
    depth = 1
    # per-prefix counters used to build unique variable names such as "alu3"
    c: defaultdict[str, int] = defaultdict(int)
    name = "test"
    for u in uops:
      if u.op in {Ops.NOOP, Ops.GROUP}: continue
      if u.op is Ops.AFTER:
        # AFTER emits nothing; it renders as whatever its first source rendered as
        r[u] = r[u.src[0]]
        continue
      if u.op is Ops.SINK:
        if u.arg is not None: name = u.arg.function_name
        continue
      if u.op in (Ops.PARAM, Ops.DEFINE_VAR):
        # DEFINE_VAR takes its name from arg[0]; PARAM names are "data<arg>" plus a shape/size suffix
        if u.op is not Ops.PARAM: r[u] = u.arg[0]
        elif isinstance(u.dtype, ImageDType): r[u] = f"data{u.arg}_{u.dtype.shape[0]}x{u.dtype.shape[1]}"
        else: r[u] = f"data{u.arg}_{sz}" if (sz:=u.ptrdtype.size) > 0 else f"data{u.arg}"
        bufs[u] = (r[u], (u.dtype, u in writable_params))
        continue

      # naming
      prefix = None
      if u.op is Ops.SPECIAL: r[u] = u.arg
      elif u.op is Ops.RANGE: r[u] = f"{axis_letters[u.arg[-1]]}idx"+range_str(u)
      else:
        prefix = {Ops.WMMA: "wmma", Ops.DEFINE_LOCAL: "temp", Ops.CONST: "const",
                  Ops.CAST: "cast", Ops.BITCAST: "cast", Ops.GEP: "gep", Ops.STACK: "cast",
                  Ops.INDEX: "bidx", Ops.DEFINE_REG: "acc", Ops.LOAD: "val"}.get(u.op, "alu")
        r[u] = f"{prefix}{c[prefix]}"

      l = cast(str, self.string_rewrite.rewrite(u, ctx=self))
      assert l is not None, f"failed to render {u.op} {u.dtype} {[(x.op,x.dtype) for x in u.src]} {u.arg}"

      # closing ops dedent before their own line is emitted
      if u.op in {Ops.ENDIF, Ops.END}: depth -= 1
      # inline case: the rendered expression replaces the name in r and no statement is emitted
      if (u.op is not Ops.CAST or u.dtype.vcount == 1) and (u.op in {Ops.CONST, Ops.GEP, Ops.INDEX, Ops.CUSTOMI} or \
        (u.op is Ops.LOAD and u.src[0].ptrdtype.addrspace == AddrSpace.REG) or \
        (u.op is Ops.CAST and isinstance(u.dtype, PtrDType)) or \
        (u.op in {Ops.STACK, *(GroupOp.ALU-{Ops.WHERE}), Ops.CAST, Ops.BITCAST} and child_count[u] == 1 and not getenv("EXPAND_SSA"))):
        r[u] = l
      else:
        # statement case: value-producing ops become "<dtype> <name> = <expr>"; SPECIAL gets no trailing ";"
        if u.op not in {Ops.RANGE, Ops.DEFINE_LOCAL, Ops.STORE, Ops.DEFINE_REG} and u.dtype != dtypes.void:
          l = f"{self.render_dtype(u.dtype)} {r[u]} = {l}" + (";" if u.op is not Ops.SPECIAL else "")
        # NOTE(review): indent unit copied verbatim from the viewed source; confirm the viewer did not collapse whitespace
        kernel.append(" "*depth + l)
        if prefix: c[prefix] += 1 # if it was used, increment
      # opening ops indent everything that follows them
      if u.op in {Ops.IF, Ops.RANGE}: depth += 1
  finally:
    del self.r

  # NOTE: this relies on bufs dict preserving order
  return (name, kernel, list(bufs.values()))
def render(self, uops:list[UOp]) -> str:
  """Render `uops` to source text.

  Calls `_render(uops)` and forwards every element of its result, followed by
  `uops` itself, as positional arguments to `render_kernel`, whose return
  value is returned unchanged.
  """
  rendered_parts = self._render(uops)
  return self.render_kernel(*rendered_parts, uops)
216

Callers 1

render (Method) · 0.95

Calls 9

render_dtype (Method) · 0.95
range_str (Function) · 0.90
getenv (Function) · 0.90
cast (Function) · 0.85
toposort (Method) · 0.80
append (Method) · 0.80
sink (Method) · 0.45
get (Method) · 0.45
rewrite (Method) · 0.45

Tested by

no test coverage detected