(self, linear:UOp, input_uops:tuple[UOp, ...]=())
| 89 | |
class GraphRunner:
  """Base for graph runners built from a linearized list of calls.

  The constructor expands each call in ``linear.src[0].src`` into one entry per device. It then
  precomputes which program variables and which symbolic launch dimensions have to be patched in on
  every replay. Concrete runners must implement ``__call__``.
  """
  def __init__(self, linear:UOp, input_uops:tuple[UOp, ...]=()):
    self.linear = linear.src[0]
    # one entry per (call, device): (device index, call.src[0], buffers after ensure_allocated(), device vars)
    self.calls: list[tuple[int, UOp, list[Buffer], dict[str, int]]] = []
    # parallel to self.calls: get_runtime(...) result for PROGRAM bodies, None otherwise
    self.runtimes: list[Any|None] = []
    # parallel to self.calls: (argument position, PARAM arg) for every call argument that is a PARAM uop
    # NOTE(review): the PARAM arg presumably indexes into input_uops (see resolve_params) -- confirm
    self.uop_replace: list[list[tuple[int, int]]] = []
    for call in self.linear.src:
      body = call.src[0]
      param_slots = [(pos, arg_uop.arg) for pos, arg_uop in enumerate(get_call_arg_uops(call)) if arg_uop.op is Ops.PARAM]
      per_device = unwrap_multi(call, resolve_params(call, input_uops))
      for dev_idx, (bufs, device_vars) in enumerate(per_device):
        allocated = [buf.ensure_allocated() for buf in bufs]
        self.calls.append((dev_idx, body, allocated, device_vars))
        self.runtimes.append(get_runtime(bufs[0].device, body) if body.op is Ops.PROGRAM else None)
        # every device entry of this call shares the same param_slots list
        self.uop_replace.append(param_slots)

    # entry index -> [(position in program vars, index into self.vars)]
    self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
    # entry index -> (index of global size in self.symbolic_dims or None, same for local size)
    self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
    # entry index -> (global_size, local_size) as tuples, recorded only when a launch dim is symbolic
    self.launch_dims_base:dict[int, tuple[tuple[int|float, ...], tuple[int, ...]]] = {}

    # (entry index, body.arg, device vars) for every entry whose body is a PROGRAM
    programs = [(j, body.arg, dv) for j, (_, body, _, dv) in enumerate(self.calls) if body.op is Ops.PROGRAM]

    def unbound(p, dv) -> list[tuple[int, str]]:
      # (position, name) of each var in p.vars not supplied by the device vars or by p.runtimevars
      return [(i, v.expr) for i, v in enumerate(p.vars) if v.expr not in dv | p.runtimevars]

    def is_symbolic(dim) -> bool: return any(not isinstance(d, (int, float)) for d in dim)

    # sorted, de-duplicated names of every var that must be filled in from var_vals at launch
    self.vars = sorted({name for _, p, dv in programs for _, name in unbound(p, dv)})
    # distinct launch shapes (local before global, per program) containing at least one non-numeric entry
    self.symbolic_dims = dedup(tuple(dim) for _, p, _ in programs for dim in (p.local_size, p.global_size) if dim and is_symbolic(dim))

    def dim_slot(dim) -> int|None:
      # position of dim in self.symbolic_dims, or None if dim is None or not one of the symbolic shapes
      if dim is None: return None
      key = tuple(dim)
      return self.symbolic_dims.index(key) if key in self.symbolic_dims else None

    for j, p, dv in programs:
      if (slots := [(i, self.vars.index(name)) for i, name in unbound(p, dv)]):
        self.var_vals_replace[j] = slots
      g_slot, l_slot = dim_slot(p.global_size), dim_slot(p.local_size)
      if g_slot is None and l_slot is None: continue
      self.launch_dims_replace[j] = (g_slot, l_slot)
      # a program with a symbolic launch shape is required to carry an explicit local size
      assert p.local_size is not None
      self.launch_dims_base[j] = (tuple(p.global_size), tuple(p.local_size))

    # accumulate per-call estimates in call order, starting from an empty Estimates()
    estimates = Estimates()
    for call in self.linear.src: estimates = estimates + estimate_uop(call)

    # used in MultiGraphRunner. tracks (offset, end, dep) ranges per base buffer id to handle suballocated buffers correctly.
    self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
    self.r_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)

    # device family (text before the first ':') of the first buffer of the first entry
    self.device = self.calls[0][2][0].device.split(":")[0]
    self.estimates = estimates.simplify()

  def __call__(self, input_uops:tuple[UOp, ...], var_vals:dict[str, int], wait=False) -> float|None:
    """Run the graph; backend subclasses must provide this. The base version always raises NotImplementedError."""
    raise NotImplementedError("override this")
| 133 |
nothing calls this directly
no test coverage detected