MCPcopy
hub / github.com/tinygrad/tinygrad / Scheduler

Class Scheduler

tinygrad/codegen/opt/postrange.py:15–329  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

13from tinygrad.renderer import Renderer
14
15class Scheduler:
def __init__(self, ast:UOp, ren:Renderer):
  """Wrap kernel AST `ast` for optimization targeting renderer `ren`.

  When `ast.arg` is set (presumably the KernelInfo written by get_optimized_ast), its dont_use_locals flag and a list copy
  of its applied_opts are carried over; otherwise dont_use_locals is False and applied_opts is empty.
  `opt_range` is a counter yielding new range ids, starting one above the largest arg[0] among `self.rngs`.
  """
  self.ast = ast
  self.ren = ren
  info = self.ast.arg
  if info is None:
    self.dont_use_locals, self.applied_opts = False, []
  else:
    self.dont_use_locals, self.applied_opts = info.dont_use_locals, list(info.applied_opts)
  # NOTE(review): self.rngs skips ranges with vmax == 0, so their ids are not considered here -- confirm that cannot collide
  highest_id = max((r.arg[0] for r in self.rngs), default=0)
  self.opt_range = count(start=highest_id + 1)
21
@property
def rngs(self):
  """RANGE uops in the AST's backward slice whose vmax is positive.

  Sorted by the axis_to_pos position of each range's axis type (the last arg element), with ties broken by the
  remaining arg elements, so the result is always grouped in axis-type order.
  """
  def order(r): return (axis_to_pos[r.arg[-1]],) + r.arg[:-1]
  live = (u for u in self.ast.backward_slice if u.op is Ops.RANGE and u.vmax > 0)
  return sorted(live, key=order)
@property
def shape_len(self) -> int:
  """How many ranges `self.rngs` currently holds."""
  ranges = self.rngs
  return len(ranges)
@property
def full_shape(self):
  """The ssimplify-ed size (src[0]) of every range in `self.rngs`, in the same order."""
  return list(map(lambda r: ssimplify(r.src[0]), self.rngs))
@property
def axis_types(self) -> list[AxisType]:
  """The axis type of every range in `self.rngs` (stored as the last element of each range's arg)."""
  return list(map(lambda r: r.arg[-1], self.rngs))
32
def shape_str(self) -> list[str]:
  """Short label per axis, e.g. ['g0', 'g1', 'l0', 'l1', 'l2', 'l3', 'l4', 'l5', 'R0', 'r0', 'r1', 'r2', 'u0', 'u1', 'u2'].

  Each label is axis_letters[axis_type] followed by a running index that counts from 0 separately for each axis type,
  in `self.axis_types` order.
  """
  seen: dict[AxisType, int] = {}
  labels: list[str] = []
  for at in self.axis_types:
    idx = seen.get(at, -1) + 1
    seen[at] = idx
    labels.append(f"{axis_letters[at]}{idx}")
  return labels
def shape_str_to_axis(self, nms:list[str]) -> tuple[int, ...]:
  """Map axis labels as produced by shape_str() (e.g. 'r0') to their axis positions.

  Positions index into shape_str(), which follows `self.rngs` order.
  Raises ValueError (from list.index) if a label is not present.
  """
  # shape_str() goes through axis_types -> rngs, which re-sorts the AST's backward slice on every call;
  # compute the labels once instead of once per requested name
  labels = self.shape_str()
  return tuple(labels.index(nm) for nm in nms)
42
def copy(self) -> Scheduler:
  """Return a fresh Scheduler over the same ast and renderer.

  The copy carries over dont_use_locals, a shallow copy of applied_opts, and tensor_core when this instance has one.
  opt_range is not copied; the new instance's __init__ rebuilds it from the AST.
  """
  dup = Scheduler(self.ast, self.ren)
  dup.dont_use_locals = self.dont_use_locals
  dup.applied_opts = self.applied_opts[:]
  try:
    dup.tensor_core = self.tensor_core
  except AttributeError:
    pass  # tensor_core is only set on some instances
  return dup
49
# class-level counter shared by all Scheduler instances: how many kernels have produced each function name so far
kernel_cnt: Final[defaultdict[str, int]] = defaultdict(int)
def get_optimized_ast(self, name_override:str|None=None) -> UOp:
  """Name the kernel and return its final AST.

  Without `name_override`, the name is built from three parts joined with a colored '_':
  - 'r' if self.reduceop is set, else 'E';
  - the sizes (vmax+1) of SPECIAL uops, blue when their arg starts with 'g' and cyan otherwise;
  - the rendered size of each range, colored by self.colors().
  Either way, Scheduler.kernel_cnt is bumped for the resulting function name. Repeats get an 'n{k}' suffix.
  Side effect: self.ast is rebound to its pm_flatten_range rewrite. The returned UOp is that AST with a
  KernelInfo arg and tag=1.
  """
  if name_override is None:
    prefix = "E" if self.reduceop is None else "r"
    specials = sorted((u for u in self.ast.toposort() if u.op is Ops.SPECIAL), key=lambda u: u.arg)
    parts = [colored(str(u.vmax+1), "blue" if u.arg[0] == "g" else "cyan") for u in specials]
    parts += [colored(r.src[0].render(), c) for r, c in zip(self.rngs, self.colors())]
    name = prefix + colored('_', 'BLACK').join([''] + parts)
  else:
    name = name_override
  fn_name = to_function_name(name)
  Scheduler.kernel_cnt[fn_name] += 1
  times_seen = Scheduler.kernel_cnt[fn_name]
  name += colored(f"n{times_seen-1}" if times_seen > 1 else "", 'BLACK')
  self.ast = graph_rewrite(self.ast, pm_flatten_range, name="flatten range")
  return self.ast.replace(arg=KernelInfo(name=name, applied_opts=tuple(self.applied_opts), dont_use_locals=self.dont_use_locals), tag=1)
63
def _output_rngs(self) -> list[UOp]:
  """Non-REDUCE ranges used by the outputs, collected in END order.

  For each END node among the AST root's sources, this takes the ranges reported by `.ranges` of a sink over that
  END's src[1:], and keeps those whose axis type is not AxisType.REDUCE.
  """
  per_end = []
  for s in self.ast.src:
    if s.op is not Ops.END: continue
    per_end.append([r for r in UOp.sink(*s.src[1:]).ranges if r.arg[-1] != AxisType.REDUCE])
  return flatten(per_end)
def _globalizable_rngs(self) -> list[UOp]:
  """Output ranges eligible to become global dimensions.

  Starts from the AxisType.LOOP ranges in self._output_rngs() and keeps, in that order, only the ones that appear in
  the `.ranges` of every STAGE node in the AST's toposort.
  """
  stages = [x for x in self.ast.toposort() if x.op is Ops.STAGE]
  # exclude any output range from global that is missing from any STAGE node's ranges
  return [r for r in self._output_rngs() if r.arg[-1] == AxisType.LOOP and all(r in s.ranges for s in stages)]

Callers 4

test_tc_up · Method · 0.90
test_max_up · Method · 0.90
copy · Method · 0.85
apply_opts · Function · 0.85

Calls

no outgoing calls

Tested by 2

test_tc_up · Method · 0.72
test_max_up · Method · 0.72

Used in the wild — real call sites across dependent graphs

searching dependent graphs…