MCPcopy
hub / github.com/tinygrad/tinygrad / __init__

Method __init__

tinygrad/engine/jit.py:91–130  ·  view source on GitHub ↗
(self, linear:UOp, input_uops:tuple[UOp, ...]=())

Source from the content-addressed store, hash-verified

89
90class GraphRunner:
def __init__(self, linear:UOp, input_uops:tuple[UOp, ...]=()):
  """Expand a LINEAR schedule into per-device calls and precompute what must be patched when the graph is replayed.

  `linear.src[0]` holds the call UOps. Each call is expanded via `unwrap_multi` into one entry per device. For every entry
  this records (device index, call.src[0], allocated buffers, device_vars), a runtime (only for PROGRAM calls, else None),
  and the (argument position, PARAM arg) pairs in `self.uop_replace`.

  For PROGRAM calls it additionally builds:
    - self.vars: sorted names of program vars that are neither in that entry's device_vars nor in the program's runtimevars
    - self.var_vals_replace[j]: (index into p.vars, index into self.vars) pairs for call entry j, only when non-empty
    - self.symbolic_dims: deduplicated global/local launch sizes containing a non-int/float element
    - self.launch_dims_replace[j] / self.launch_dims_base[j]: indices into self.symbolic_dims and the unresolved sizes

  Raises:
    IndexError: if no call entry was produced or the first entry has no buffers (see the `self.device` line).
    AssertionError: if a program has a symbolic launch size but `local_size` is None.
  """
  self.linear = linear.src[0]
  self.calls: list[tuple[int, UOp, list[Buffer], dict[str, int]]] = []
  self.runtimes: list[Any|None] = []
  self.uop_replace: list[list[tuple[int, int]]] = []
  for call in self.linear.src:
    prg = call.src[0]
    # (argument position, PARAM arg) pairs; the same list object is shared by every per-device entry of this call
    param_slots = [(p, b.arg) for p, b in enumerate(get_call_arg_uops(call)) if b.op is Ops.PARAM]
    for dev_idx, (bufs, device_vars) in enumerate(unwrap_multi(call, resolve_params(call, input_uops))):
      # buffers are allocated before the runtime is fetched, as in the original ordering
      self.calls.append((dev_idx, prg, [b.ensure_allocated() for b in bufs], device_vars))
      self.runtimes.append(get_runtime(bufs[0].device, prg) if prg.op is Ops.PROGRAM else None)
      self.uop_replace.append(param_slots)

  self.var_vals_replace:dict[int, list[tuple[int, int]]] = {}
  self.launch_dims_replace:dict[int, tuple[int|None, int|None]] = {}
  self.launch_dims_base:dict[int, tuple[tuple[int|float, ...], tuple[int, ...]]] = {}

  def is_sym_dim(dim) -> bool: return not all(isinstance(d, (int, float)) for d in dim)

  # (call entry index, program arg, names excluded from self.vars). The exclusion mapping is built once per program and
  # reused by both filters below, so self.vars and the per-call replace indices are derived from the same rule.
  progs = [(j, u.arg, dv | u.arg.runtimevars) for j, (_, u, _, dv) in enumerate(self.calls) if u.op is Ops.PROGRAM]
  self.vars = sorted({v.expr for _, p, skip in progs for v in p.vars if v.expr not in skip})
  # self.vars is sorted and unique, so this gives the same positions as self.vars.index() without a linear scan per var
  var_pos = {name: i for i, name in enumerate(self.vars)}
  self.symbolic_dims = dedup(tuple(d) for _, p, _ in progs for d in (p.local_size, p.global_size) if d and is_sym_dim(d))

  def find_symbolic_dim(dim) -> int|None:
    # index of dim in self.symbolic_dims, or None when dim is None or not symbolic
    if dim is None: return None
    key = tuple(dim)
    return self.symbolic_dims.index(key) if key in self.symbolic_dims else None

  for j, p, skip in progs:
    if (var_map := [(i, var_pos[v.expr]) for i, v in enumerate(p.vars) if v.expr not in skip]):
      self.var_vals_replace[j] = var_map
    global_dim_idx, local_dim_idx = find_symbolic_dim(p.global_size), find_symbolic_dim(p.local_size)
    if global_dim_idx is not None or local_dim_idx is not None:
      self.launch_dims_replace[j] = (global_dim_idx, local_dim_idx)
      assert p.local_size is not None, f"call {j} has a symbolic launch size but no local_size"
      self.launch_dims_base[j] = (tuple(p.global_size), tuple(p.local_size))

  estimates = sum((estimate_uop(call) for call in self.linear.src), Estimates())

  # used in MultiGraphRunner. tracks (offset, end, dep) ranges per base buffer id to handle suballocated buffers correctly.
  self.w_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)
  self.r_dependency_map: dict[int, list[tuple[int, int, Any]]] = collections.defaultdict(list)

  # device name with any ":<suffix>" stripped, taken from the first buffer of the first call entry
  self.device, self.estimates = self.calls[0][2][0].device.split(":")[0], estimates.simplify()
131
def __call__(self, input_uops:tuple[UOp, ...], var_vals:dict[str, int], wait=False) -> float|None:
  """Run the captured graph with the given inputs and var values.

  Abstract on the base class: subclasses are expected to override it, and this implementation always raises
  NotImplementedError.
  """
  raise NotImplementedError("override this")
133

Callers

nothing calls this directly

Calls 12

get_call_arg_uops — Function · 0.90
unwrap_multi — Function · 0.90
resolve_params — Function · 0.90
get_runtime — Function · 0.90
dedup — Function · 0.90
estimate_uop — Function · 0.90
Estimates — Class · 0.85
append — Method · 0.80
ensure_allocated — Method · 0.80
split — Method · 0.80
index — Method · 0.45
simplify — Method · 0.45

Tested by

no test coverage detected