| 123 | except IndexError as e: raise KernelOptError from e |
| 124 | |
| 125 | def apply_opt(self, opt:Opt, append_opt:bool=True): |
| 126 | if opt.op is OptOps.NOLOCALS: |
| 127 | check(all(x not in {AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE} for x in self.axis_types), "no locals can't have locals") |
| 128 | if append_opt: self.applied_opts.append(opt) |
| 129 | self.dont_use_locals = True |
| 130 | return |
| 131 | |
| 132 | if opt.op in {OptOps.LOCAL, OptOps.GROUP, OptOps.GROUPTOP}: |
| 133 | check(self.ren.has_local, "locals needed for opt") |
| 134 | |
| 135 | rng = self.rngs[real_axis] if (real_axis:=self.real_axis(opt.op, opt.axis)) >= 0 else UOp(Ops.NOOP) |
| 136 | |
| 137 | opt_to_at = { |
| 138 | OptOps.LOCAL: AxisType.LOCAL, OptOps.UPCAST: AxisType.UPCAST, |
| 139 | OptOps.UNROLL: AxisType.UNROLL, OptOps.GROUP: AxisType.GROUP_REDUCE, |
| 140 | OptOps.GROUPTOP: AxisType.GROUP_REDUCE, OptOps.THREAD: AxisType.THREAD} |
| 141 | |
| 142 | ret = None |
| 143 | if opt.op in opt_to_at: |
| 144 | amt:int = int(rng.vmax+1) if opt.arg == 0 else cast(int, opt.arg) |
| 145 | |
| 146 | # copied from kernel.py. prevents METAL compiler hangs |
| 147 | if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or \ |
| 148 | (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})): |
| 149 | upcast_local_sz = prod([self.full_shape[a] for a in self.axes_of(AxisType.UPCAST, AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)]) |
| 150 | smem_sz = amt*upcast_local_sz*self.reduceop.dtype.itemsize |
| 151 | check(smem_sz <= self.ren.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.ren.shared_max}") |
| 152 | if self.reduceop is not None and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP}): |
| 153 | # We currently don't support a group within another reduce, TODO: fix if-contexts |
| 154 | reduce = [u for u in self.ast.backward_slice if u.op is Ops.REDUCE and rng in merge_dicts([r.ranges for r in u.src[1:]])][0] |
| 155 | check(not any(u.arg[-1] in (AxisType.REDUCE, AxisType.UNROLL, AxisType.GROUP_REDUCE) for u in reduce.ranges), |
| 156 | "cannot have a GROUP_REDUCE inside another reduce") |
| 157 | |
| 158 | if opt.op is OptOps.UNROLL: |
| 159 | check(amt <= 32, "don't unroll more than 32") |
| 160 | check(rng.arg[-1] in {AxisType.GROUP_REDUCE, AxisType.REDUCE}, "unroll is for GROUP_REDUCE/REDUCE") |
| 161 | if opt.op is OptOps.UPCAST: |
| 162 | check((self.ren is not None and self.ren.target.device == "DSP") or amt <= 16, "don't upcast more than 16") |
| 163 | check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOCAL, AxisType.LOOP}, f"upcast is for GLOBAL/LOCAL/LOOP, not {rng.arg[-1]}") |
| 164 | if opt.op is OptOps.LOCAL: |
| 165 | check(not self.dont_use_locals, "can't use locals") |
| 166 | check(rng.arg[-1] in {AxisType.GLOBAL, AxisType.LOOP}, "local is for globals") |
| 167 | if opt.op is OptOps.THREAD: |
| 168 | check(self.ren is not None and self.ren.has_threads, "target does not support threads") |
| 169 | check(self.ren is not None and self.ren.global_max is not None and amt <= self.ren.global_max[0], "too many threads") |
| 170 | check(all(x is not AxisType.THREAD for x in self.axis_types), "already threaded") |
| 171 | check(rng in self._globalizable_rngs(), "can't apply range to this dim") |
| 172 | if opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: |
| 173 | check(all(x.op is not OptOps.TC for x in self.applied_opts), "no grouping with tensor cores") # TODO: why is this wrong? |
| 174 | check(not self.dont_use_locals, "can't use locals") |
| 175 | check(rng.arg[-1] == AxisType.REDUCE, "group is for reduce") |
| 176 | ret = self.shift_to(rng, amt, opt_to_at[opt.op], top=opt.op in {OptOps.GROUPTOP, OptOps.THREAD}) |
| 177 | elif opt.op is OptOps.TC: |
| 178 | check(len(self.applied_opts) == 0, "tensor core opts must be first") # TODO: remove the need for this by having warps |
| 179 | check(opt.axis is not None, "tensor core opts must have an axis") |
| 180 | check(opt.arg is not None and isinstance(opt.arg, tuple) and len(opt.arg) == 3, "tensor core opts must have valid arg") |
| 181 | check(-1 <= (tc_select:=cast(tuple, opt.arg)[0]) < len(self.ren.tensor_cores), "tensor core opts must have valid tc_select") |
| 182 | check(0 <= (tc_opt:=cast(tuple, opt.arg)[1]) <= 2, "tensor core opts must have valid tc_opt") |