| 88 | |
# Enumerate the single-step candidate schedules reachable from a scheduler (see docstring below).
def get_kernel_actions(s:Scheduler, include_0=True, max_up:int|None=None) -> dict[int, Scheduler]:
  """Return every schedule reachable from `s` by applying exactly one entry of `actions`.

  Keys: 0 maps to `s` itself (present only when `include_0` is true); `k+1` maps to a copy of `s`
  with `actions[k]` applied.

  An action is skipped when:
    - it targets an axis (and is not an OptOps.TC action) and either `s.real_axis` raises KernelOptError,
      the resolved axis is >= `s.shape_len`, or the action's arg equals that axis' full size while
      `Opt(op, axis, 0)` is also in the list (presumably the arg-0 form yields the same schedule -- TODO confirm);
    - `apply_opt` raises KernelOptError;
    - the product of UPCAST/UNROLL axis sizes, integer-divided by `prod(tc.dims)//tc.threads` when the copy
      exposes a truthy `tensor_core`, exceeds `max_up`, or the product of WARP/LOCAL/GROUP_REDUCE axis sizes
      exceeds the local cap.

  `max_up` falls back to getenv("BEAM_UPCAST_MAX", 256); the local cap is getenv("BEAM_LOCAL_MAX", 1024).
  When BEAM_LOG_SURPASS_MAX is set, size-cap rejections are printed.
  """
  acted = {0: s} if include_0 else {}
  if max_up is None: max_up = getenv("BEAM_UPCAST_MAX", 256)
  max_lcl = getenv("BEAM_LOCAL_MAX", 1024)
  # iterate (and membership-test) against a copy of `actions`; it is not mutated here
  opts_to_try = actions.copy()

  for idx, opt in enumerate(opts_to_try):
    # axis-targeting opts (TC excluded) are pre-screened against the unmodified scheduler
    if opt.axis is not None and opt.op is not OptOps.TC:
      try: real_ax = s.real_axis(opt.op, opt.axis)
      except KernelOptError: continue
      if real_ax >= s.shape_len: continue
      if s.full_shape[real_ax] == opt.arg and Opt(opt.op, opt.axis, 0) in opts_to_try: continue

    trial = s.copy()
    try:
      trial.apply_opt(opt)
      # tensor-core upcasts are discounted from the upcast budget
      tc = trial.tensor_core if hasattr(trial, 'tensor_core') else None
      tc_up = prod(tc.dims)//tc.threads if tc else 1
      up = lcl = 1
      for size, kind in zip(trial.full_shape, trial.axis_types):
        if kind in (AxisType.UPCAST, AxisType.UNROLL): up *= size
        elif kind in (AxisType.WARP, AxisType.LOCAL, AxisType.GROUP_REDUCE): lcl *= size
      # names below are kept as-is: the `{expr=}` f-string fields print the expression text verbatim
      if up//tc_up > max_up or lcl > max_lcl:
        if getenv("BEAM_LOG_SURPASS_MAX"): print(f"too many upcast/local. {up//tc_up=}, {max_up=}, {lcl=}, {max_lcl=}")
        continue
      acted[idx+1] = trial
    except KernelOptError: pass
  return acted
| 112 | |
# Beam-search debug level, read through getenv from the BEAM_DEBUG setting.
BEAM_DEBUG = getenv("BEAM_DEBUG")
# Worker pool handle for beam search; starts unset (presumably created lazily by beam_search -- TODO confirm).
beam_pool = None
| 114 | def beam_search(s:Scheduler, rawbufs:list[Buffer], amt:int, allow_test_size=True, disable_cache=IGNORE_BEAM_CACHE.value): |