| 13 | from tinygrad.renderer import Renderer |
| 14 | |
| 15 | class Scheduler: |
def __init__(self, ast:UOp, ren:Renderer):
  """Wrap a kernel ast for scheduling on renderer `ren`.

  When `ast.arg` is not None, dont_use_locals and applied_opts are seeded from it
  (get_optimized_ast stores a KernelInfo there); otherwise they start as False / [].
  """
  self.ast, self.ren = ast, ren
  info = self.ast.arg
  if info is None:
    self.dont_use_locals = False
    self.applied_opts = []
  else:
    self.dont_use_locals = info.dont_use_locals
    self.applied_opts = list(info.applied_opts)
  # counter starts strictly above every existing range's arg[0] (0 when there are no ranges);
  # presumably used to number ranges added by later optimizations
  first_free_id = max((rng.arg[0] for rng in self.rngs), default=0) + 1
  self.opt_range = count(start=first_free_id)
| 21 | |
@property
def rngs(self):
  """RANGE uops from the ast's backward slice whose vmax is > 0, ordered by axis type.

  Ranges are ordered first by axis_to_pos of their AxisType (the last element of arg),
  then by the remaining elements of arg. This is recomputed on every access.
  """
  found = [u for u in self.ast.backward_slice if u.op is Ops.RANGE and u.vmax > 0]
  found.sort(key=lambda rng: (axis_to_pos[rng.arg[-1]],) + rng.arg[:-1])
  return found
@property
def shape_len(self) -> int:
  """How many ranges `rngs` currently reports."""
  ranges = self.rngs
  return len(ranges)
@property
def full_shape(self):
  """Simplified first source (src[0]) of each range, in `rngs` order."""
  return list(map(lambda rng: ssimplify(rng.src[0]), self.rngs))
@property
def axis_types(self) -> list[AxisType]:
  """AxisType of each range (the last element of its arg), in `rngs` order."""
  kinds = [rng.arg[-1] for rng in self.rngs]
  return kinds
| 32 | |
# produces labels such as ['g0', 'g1', 'l0', 'l1', 'R0', 'r0', 'r1', 'u0', 'u1']
def shape_str(self) -> list[str]:
  """Label every axis with its axis letter plus a running index counted per AxisType."""
  seen: dict[AxisType, int] = {}
  labels: list[str] = []
  for kind in self.axis_types:
    # first occurrence of an AxisType gets index 0, the next 1, and so on
    seen[kind] = seen.get(kind, -1) + 1
    labels.append(f"{axis_letters[kind]}{seen[kind]}")
  return labels
def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]:
  """Translate axis labels (as produced by shape_str, e.g. 'g0', 'r1') into axis positions.

  Returns one index per entry of `nms`, in the same order.
  Raises ValueError (from list.index) when a label is not present.
  """
  # shape_str() re-walks and re-sorts the ast's ranges, so build the label list once
  # instead of once per requested name
  labels = self.shape_str()
  return tuple(labels.index(nm) for nm in nms)
| 42 | |
def copy(self) -> Scheduler:
  """Return a new Scheduler over the same ast and renderer, carrying over optimization state.

  dont_use_locals and applied_opts (shallow-copied) are transferred. tensor_core is
  transferred only if this instance has that attribute.
  """
  clone = Scheduler(self.ast, self.ren)
  clone.dont_use_locals = self.dont_use_locals
  clone.applied_opts = list(self.applied_opts)
  # tensor_core is not set in __init__, so it may not exist on every instance
  if hasattr(self, 'tensor_core'):
    clone.tensor_core = self.tensor_core
  return clone
| 49 | |
# class-level tally shared by all instances: how many kernels have been named with each function name
kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
def get_optimized_ast(self, name_override:str|None=None) -> UOp:
  """Finalize the ast: flatten its ranges and attach a KernelInfo.

  Without `name_override`, the kernel name is built from:
    - "r" if self.reduceop is not None, else "E";
    - the vmax+1 of each SPECIAL uop (blue when its arg starts with "g", cyan otherwise);
    - the rendered extent of each range, colored by self.colors().
  Repeated function names receive an "n<k>" suffix via the shared kernel_cnt tally.
  Side effect: self.ast is replaced by its pm_flatten_range rewrite.
  """
  if name_override is None:
    prefix = "E" if self.reduceop is None else "r"
    specials = sorted((u for u in self.ast.toposort() if u.op is Ops.SPECIAL), key=lambda u: u.arg)
    parts = [colored(str(u.vmax+1), "cyan" if u.arg[0] != "g" else "blue") for u in specials]
    parts += [colored(rng.src[0].render(), clr) for rng, clr in zip(self.rngs, self.colors())]
    name = prefix + colored('_', 'BLACK').join([''] + parts)
    function_name = to_function_name(name)
    Scheduler.kernel_cnt[function_name] += 1
    times_seen = Scheduler.kernel_cnt[function_name]
    # only the second and later kernels with the same function name get a suffix
    name += colored(f"n{times_seen-1}" if times_seen > 1 else "", 'BLACK')
  else:
    name = name_override
  self.ast = graph_rewrite(self.ast, pm_flatten_range, name="flatten range")
  info = KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals)
  return self.ast.replace(arg=info, tag=1)
| 63 | |
def _output_rngs(self) -> list[UOp]:
  """Non-REDUCE ranges of each END op directly under the ast, in order.

  For every END in ast.src, this takes the ranges of UOp.sink(*end.src[1:]) and keeps
  those whose AxisType (last element of arg) is not REDUCE.
  """
  found: list[UOp] = []
  for end in self.ast.src:
    if end.op is not Ops.END: continue
    found.extend(r for r in UOp.sink(*end.src[1:]).ranges if r.arg[-1] != AxisType.REDUCE)
  return found
def _globalizable_rngs(self) -> list[UOp]:
  """Output LOOP ranges that may be made global.

  Starts from the LOOP-typed ranges of _output_rngs() and keeps only those present in the
  ranges of every Ops.STAGE uop in the ast. Order follows _output_rngs().
  """
  candidates = [r for r in self._output_rngs() if r.arg[-1] == AxisType.LOOP]
  # exclude any output range from global that does not appear in the ranges of every STAGE op
  stages = [x for x in self.ast.toposort() if x.op is Ops.STAGE]
  return [r for r in candidates if all(r in st.ranges for st in stages)]
no outgoing calls
searching dependent graphs…