MCPcopy
hub / github.com/tinygrad/tinygrad / get_kernel_actions

Function get_kernel_actions

tinygrad/codegen/opt/search.py:90–111  ·  view source on GitHub ↗
(s:Scheduler, include_0=True, max_up:int|None=None)

Source from the content-addressed store, hash-verified

88
89# get dictionary of all possible actions
def get_kernel_actions(s:Scheduler, include_0=True, max_up:int|None=None) -> dict[int, Scheduler]:
  """Enumerate every optimization in `actions` that can be applied to `s`.

  Returns a dict keyed by action number: key `k >= 1` holds a copy of `s` with `actions[k-1]` applied,
  and key 0 holds `s` itself (only when `include_0` is true).

  An action is left out of the result when:
    * it names an axis (and is not a TC opt) and `s.real_axis` raises KernelOptError,
    * the resolved axis is not below `s.shape_len`,
    * the axis' full size equals the action's arg while the arg-0 form of the same opt is also in the
      action list (presumably because the two would be duplicates -- confirm against `apply_opt`),
    * `apply_opt` (or the size accounting that follows it) raises KernelOptError,
    * the resulting upcast/unroll product, divided by the tensor-core factor, exceeds `max_up`, or the
      warp/local/group-reduce product exceeds the local limit.

  `max_up` falls back to the BEAM_UPCAST_MAX env var (default 256) when None; the local limit always
  comes from BEAM_LOCAL_MAX (default 1024). Setting BEAM_LOG_SURPASS_MAX prints a line for every
  action rejected by those limits.
  """
  acted: dict[int, Scheduler] = {0:s} if include_0 else {}
  if max_up is None: max_up = getenv("BEAM_UPCAST_MAX", 256)
  max_lcl = getenv("BEAM_LOCAL_MAX", 1024)
  kernel_actions = actions.copy()
  upcast_kinds = (AxisType.UPCAST, AxisType.UNROLL)
  local_kinds = (AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE)

  # action numbers are 1-based so that 0 stays reserved for the untouched scheduler
  for action_num, opt in enumerate(kernel_actions, start=1):
    if opt.axis is not None and opt.op is not OptOps.TC:
      try:
        real_ax = s.real_axis(opt.op, opt.axis)
      except KernelOptError:
        continue
      if real_ax >= s.shape_len: continue
      if s.full_shape[real_ax] == opt.arg and Opt(opt.op, opt.axis, 0) in kernel_actions: continue

    candidate = s.copy()
    try:
      candidate.apply_opt(opt)
      tensor_core = candidate.tensor_core if hasattr(candidate, 'tensor_core') else None
      tc_up = prod(tensor_core.dims)//tensor_core.threads if tensor_core else 1
      up, lcl = 1, 1
      for size, kind in zip(candidate.full_shape, candidate.axis_types):
        if kind in upcast_kinds: up *= size
        elif kind in local_kinds: lcl *= size
      if up//tc_up > max_up or lcl > max_lcl:
        # names inside the f-string are printed verbatim by the `=` specifier, so they stay as-is
        if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many upcast/local. {up//tc_up=}, {max_up=}, {lcl=}, {max_lcl=}")
      else:
        acted[action_num] = candidate
    except KernelOptError:
      continue
  return acted
112
113beam_pool, BEAM_DEBUG = None, getenv("BEAM_DEBUG")
114def beam_search(s:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value):

Callers 3

test_tc_up — Method · 0.90
test_max_up — Method · 0.90
beam_search — Function · 0.85

Calls 6

getenv — Function · 0.90
Opt — Class · 0.90
prod — Function · 0.90
real_axis — Method · 0.80
apply_opt — Method · 0.80
copy — Method · 0.45

Tested by 2

test_tc_up — Method · 0.72
test_max_up — Method · 0.72

Used in the wild real call sites across dependent graphs

searching dependent graphs…