MCPcopy
hub / github.com/tinygrad/tinygrad / apply_opt

Method apply_opt

tinygrad/codegen/opt/postrange.py:125–217  ·  view source on GitHub ↗
(self, opt:Opt, append_opt:bool=True)

Source from the content-addressed store, hash-verified

123 except IndexError as e: raise KernelOptError from e
124
125 def apply_opt(self, opt:Opt, append_opt:bool=True):
126 if opt.op is OptOps.NOLOCALS:
127 check(all(x not in {AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE} for x in self.axis_types), "no locals can't have locals")
128 if append_opt: self.applied_opts.append(opt)
129 self.dont_use_locals = True
130 return
131
132 if opt.op in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}:
133 check(self.ren.has_local, "locals needed for opt")
134
135 rng = self.rngs[real_axis] if (real_axis:=self.real_axis(opt.op, opt.axis)) >= 0 else UOp(Ops.NOOP)
136
137 opt_to_at = {
138 OptOps.LOCAL: AxisType.LOCAL, OptOps.UPCAST: AxisType.UPCAST,
139 OptOps.UNROLL: AxisType.UNROLL, OptOps.GROUP: AxisType.GROUP_REDUCE,
140 OptOps.GROUPTOP: AxisType.GROUP_REDUCE, OptOps.THREAD: AxisType.THREAD}
141
142 ret = None
143 if opt.op in opt_to_at:
144 amt:int = int(rng.vmax+1) if opt.arg == 0 else cast(int, opt.arg)
145
146 # copied from kernel.py. prevents METAL compiler hangs
147 if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \
148 (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
149 upcast_local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST, AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)])
150 smem_sz = amt*upcast_local_sz*self.reduceop.dtype.itemsize
151 check(smem_sz <= self.ren.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.ren.shared_max}")
152 if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP}):
153 # We currently don't support a group within another reduce, TODO: fix if-contexts
154 reduce = [u for u in self.ast.backward_slice if u.op is Ops.REDUCE and rng in merge_dicts([r.ranges for r in u.src[1:]])][0]
155 check(not any(u.arg[-1] in (AxisType.REDUCE, AxisType.UNROLL, AxisType.GROUP_REDUCE) for u in reduce.ranges),
156 "cannot have a GROUP_REDUCE inside another reduce")
157
158 if opt.op is OptOps.UNROLL:
159 check(amt <= 32, "don't unroll more than 32")
160 check(rng.arg[-1] in {AxisType.GROUP_REDUCE, AxisType.REDUCE}, "unroll is for GROUP_REDUCE/REDUCE")
161 if opt.op is OptOps.UPCAST:
162 check((self.ren is not None and self.ren.target.device == "DSP") or amt <= 16, "don't upcast more than 16")
163 check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP}, f"upcast is for GLOBAL/LOCAL/LOOP, not {rng.arg[-1]}")
164 if opt.op is OptOps.LOCAL:
165 check(not self.dont_use_locals, "can't use locals")
166 check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOOP}, "local is for globals")
167 if opt.op is OptOps.THREAD:
168 check(self.ren is not None and self.ren.has_threads, "target does not support threads")
169 check(self.ren is not None and self.ren.global_max is not None and amt <= self.ren.global_max[0], "too many threads")
170 check(all(x is not AxisType.THREAD for x in self.axis_types), "already threaded")
171 check(rng in self._globalizable_rngs(), "can't apply range to this dim")
172 if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}:
173 check(all(x.op is not OptOps.TC for x in self.applied_opts), "no grouping with tensor cores") # TODO: why is this wrong?
174 check(not self.dont_use_locals, "can't use locals")
175 check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce")
176 ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD})
177 elif opt.op is OptOps.TC:
178 check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps
179 check(opt.axis is not None, "tensor core opts must have an axis")
180 check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg")
181 check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.ren.tensor_cores), "tensor core opts must have valid tc_select")
182 check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt")

Callers 6

_apply_tc_opt · Method · 0.95
apply_opts · Function · 0.95
test_tc_up · Method · 0.95
hand_coded_optimizations · Function · 0.80
get_kernel_actions · Function · 0.80
beam_search · Function · 0.80

Calls 15

real_axis · Method · 0.95
axes_of · Method · 0.95
_globalizable_rngs · Method · 0.95
shift_to · Method · 0.95
_apply_tc_opt · Method · 0.95
check · Function · 0.90
UOp · Class · 0.90
prod · Function · 0.90
merge_dicts · Function · 0.90
KernelOptError · Class · 0.90
round_up · Function · 0.90
graph_rewrite · Function · 0.90

Tested by 1

test_tc_up · Method · 0.76