MCPcopy
hub / github.com/tinygrad/tinygrad / hand_coded_optimizations

Function hand_coded_optimizations

tinygrad/codegen/opt/heuristic.py:8–191  ·  view source on GitHub ↗

Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false. Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N). Keyword arguments: use_tensor_cores, extra_opts, tc_select, tc_opt.

(k:Scheduler)

Source from the content-addressed store, hash-verified

6from tinygrad.codegen.opt.postrange import Scheduler
7
8def hand_coded_optimizations(k:Scheduler) -> Scheduler:
9 # first try the tensor cores
10 """ Attempts to apply a tensor core optimization to the kernel. If one exists and applies properly, return true, otherwise return false.
11 Tensor cores are optimized instructions that matrix multiply-accumulate across a wave of threads: D(M, N) = A(M, K) * B(K, N) + C(M, N).
12
13 Keyword arguments:
14 use_tensor_cores -- controls how tensor cores are applied (default 1)
15 0: will disable any tensor core matching
16 1: enable tensor cores
17 2: apply tensor core shape but don't use UOp.WMMA
18 extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None)
19 tc_select -- specifies which tensor core(s) to use for optimization (default -1)
20 -1: iterates through all available tensor cores in order and uses the first one that matches the requirements (dims and dtypes)
21 [0-N]: uses only the n'th tensor core available; useful for search
22 tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise)
23 0: applies to only kernels with a single reduce axis and direct Ops.LOAD into Ops.MUL
24 1: allows kernels with multiple reduce axes and also multiplication of Ops.CAST'd buffers
25 2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed
26 """
27 # NOTE: unless TC_OPT is > 0, we only trigger tensor cores if there's only one reduce axis
28 if USE_TC > 0 and (len(k.axes_of(AxisType.GROUP_REDUCE, AxisType.REDUCE)) == 1 or (TC_OPT.value >= 1)):
29 good_tc_opt = False
30 tk = k.copy()
31 try: # check TC first and apply hand-coded opts if successful
32 rngs = tk.apply_opt(Opt(OptOps.TC, 0, (TC_SELECT.value, TC_OPT.value, USE_TC.value)))
33 good_tc_opt = True
34 except KernelOptError:
35 pass
36 # skip hand-coded TC opts if AMX, upcasting will make kernel slower
37 if good_tc_opt and "AMX" not in k.ren.target.arch:
38 if rngs is not None:
39 for tc_dim in [1,0]: # attempt to upcast M and N
40 szs = [sz for sz in [5,4,3,2] if rngs[tc_dim].src[0].divides(sz) is not None]
41 if szs:
42 # set it to the replaced range
43 rngs[tc_dim] = tk.apply_opt(Opt(OptOps.UPCAST, tk.rngs.index(rngs[tc_dim]), szs[0]))[0]
44 if (szs := [sz for sz in [4,2] if rngs[0].src[0].divides(sz) is not None]): # attempt to local N
45 tk.apply_opt(Opt(OptOps.LOCAL, tk.rngs.index(rngs[0]), szs[0]))
46 return tk
47
48 # make a copy so it does not mutate the input
49 k = k.copy()
50
51 # upcast float4 images, this must be early so we don't accidentally add locals before the upcast
52 if IMAGE:
53 for buf_index,buf in enumerate(k.bufs):
54 if isinstance(buf.src[0].dtype, PtrDType) and ImageDType.valid_dims(buf.src[0].dtype, k.ren.target.arch):
55 # part of is_expanded
56 unit_stride_axes_mul_4 = [k.rngs.index(c) for c in k.bufs[buf_index].src[1].get_idx().split_uop(Ops.ADD) if
57 c.op is Ops.RANGE and (c.vmax+1)%4 == 0]
58 if len(unit_stride_axes_mul_4):
59 if (axis:=unit_stride_axes_mul_4[0]) in k.upcastable_dims:
60 k.apply_opt(Opt(OptOps.UPCAST, axis, 4))
61 elif axis in k.unrollable_dims:
62 k.apply_opt(Opt(OptOps.UNROLL, k.unrollable_dims.index(axis), 4))
63
64 # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
65 MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)

Callers 1

apply_opts · Function · 0.90

Calls 15

Opt · Class · 0.90
getenv · Function · 0.90
resolve · Function · 0.90
prod · Function · 0.90
axes_of · Method · 0.80
apply_opt · Method · 0.80
divides · Method · 0.80
valid_dims · Method · 0.80
split_uop · Method · 0.80
get_idx · Method · 0.80
ranges_of · Method · 0.80
append · Method · 0.80

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…