Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true; otherwise, return false. Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N). Keyword arguments: use_tensor_cores, extra_opts, tc_select, tc_opt.
(k:Scheduler)
| 6 | from tinygrad.codegen.opt.postrange import Scheduler |
| 7 | |
| 8 | def hand_coded_optimizations(k:Scheduler) -> Scheduler: |
| 9 | # first try the tensor cores |
| 10 | """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false. |
| 11 | Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N). |
| 12 | |
| 13 | Keyword arguments: |
| 14 | use_tensor_cores -- controls how tensor cores are applied (default 1) |
| 15 | 0: will disable any tensor core matching |
| 16 | 1: enable tensor cores |
| 17 | 2: apply tensor core shape but don't use UOp.WMMA |
| 18 | extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None) |
| 19 | tc_select -- specifies which tensor core(s) to use for optimization (default -1) |
| 20 | -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes) |
| 21 | [0-N]: uses only the n'th tensor core available; useful for search |
| 22 | tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise) |
| 23 | 0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL |
| 24 | 1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers |
| 25 | 2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by padding those axes as needed |
| 26 | """ |
| 27 | # NOTE: unless TC_OPT is > 0, we only trigger tensor cores if there's only one reduce axis |
| 28 | if USE_TC > 0 and (len(k.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (TC_OPT.value >= 1)): |
| 29 | good_tc_opt = False |
| 30 | tk = k.copy() |
| 31 | try: # check TC first and apply hand-coded opts if successful |
| 32 | rngs = tk.apply_opt(Opt(OptOps.TC, 0, (TC_SELECT.value, TC_OPT.value, USE_TC.value))) |
| 33 | good_tc_opt = True |
| 34 | except KernelOptError: |
| 35 | pass |
| 36 | # skip hand-coded TC opts if AMX, upcasting will make kernel slower |
| 37 | if good_tc_opt and "AMX" not in k.ren.target.arch: |
| 38 | if rngs is not None: |
| 39 | for tc_dim in [1,0]: # attempt to upcast M and N |
| 40 | szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None] |
| 41 | if szs: |
| 42 | # set it to the replaced range |
| 43 | rngs[tc_dim] = tk.apply_opt(Opt(OptOps.UPCAST, tk.rngs.index(rngs[tc_dim]), szs[0]))[0] |
| 44 | if (szs := [sz for sz in [4,2] if rngs[0].src[0].divides(sz) is not None]): # attempt to local N |
| 45 | tk.apply_opt(Opt(OptOps.LOCAL, tk.rngs.index(rngs[0]), szs[0])) |
| 46 | return tk |
| 47 | |
| 48 | # make a copy so it does not mutate the input |
| 49 | k = k.copy() |
| 50 | |
| 51 | # upcast float4 images, this must be early so we don't accidentally add locals before the upcast |
| 52 | if IMAGE: |
| 53 | for buf_index,buf in enumerate(k.bufs): |
| 54 | if isinstance(buf.src[0].dtype, PtrDType) and ImageDType.valid_dims(buf.src[0].dtype, k.ren.target.arch): |
| 55 | # part of is_expanded |
| 56 | unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if |
| 57 | c.op is Ops.RANGE and (c.vmax+1)%4 == 0] |
| 58 | if len(unit_stride_axes_mul_4): |
| 59 | if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims: |
| 60 | k.apply_opt(Opt(OptOps.UPCAST, axis, 4)) |
| 61 | elif axis in k.unrollable_dims: |
| 62 | k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims.index(axis), 4)) |
| 63 | |
| 64 | # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat |
| 65 | MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4) |
no test coverage detected
searching dependent graphs…