(linear:UOp, beam:int|None=None, validate=False)
# Rebind pm_exec to the combination of pm_hcq_exec and the previously defined pm_exec,
# with pm_hcq_exec as the left operand.
# NOTE(review): operand order presumably determines which patterns are tried first -- confirm.
pm_exec = pm_hcq_exec + pm_exec
| 245 | |
def compile_linear(linear:UOp, beam:int|None=None, validate=False) -> UOp:
  """Apply the compile-time rewrite passes to `linear` and return the rewritten UOp.

  Passes run in this order, each as a graph_rewrite with walk=True:
    1. pm_validate -- only when `validate` is truthy.
    2. pm_beam, with the beam value passed as ctx -- only when that value is >= 1.
       The beam value is `beam` when given, otherwise BEAM.value.
    3. pm_compile ("precompile kernels").
    4. pm_optimize_local_size ("optimize local size").
  """
  if validate:
    linear = graph_rewrite(linear, pm_validate, name="validate", walk=True)
  # An explicit `beam` argument takes precedence; BEAM.value is only read when beam is None.
  beam_val = beam if beam is not None else BEAM.value
  if beam_val >= 1:
    linear = graph_rewrite(linear, pm_beam, ctx=beam_val, walk=True)
  precompiled = graph_rewrite(linear, pm_compile, name="precompile kernels", walk=True)
  return graph_rewrite(precompiled, pm_optimize_local_size, name="optimize local size", walk=True)
| 251 | |
| 252 | def run_linear(linear:UOp, var_vals:dict[str, int]|None=None, input_uops:tuple[UOp, ...]=(), update_stats=True, jit=False, wait=False): |
| 253 | if not jit: linear = compile_linear(linear, validate=VALIDATE_WITH_CPU) |
searching dependent graphs…