MCPcopy
hub / github.com/tinygrad/tinygrad / create_linear_with_vars

Function create_linear_with_vars

tinygrad/schedule/__init__.py:122–147  ·  view source on GitHub ↗
(big_sink:UOp)

Source from the content-addressed store, hash-verified

120
@track_rewrites(lambda _,ret: f"Schedule {pluralize('Kernel', len(ret[0].src))}")
def create_linear_with_vars(big_sink:UOp) -> tuple[UOp, dict[str, int]]:
  """Turn `big_sink` into a linear UOp and gather the values of the variables it uses.

  `big_sink` is rewritten with `pm_schedule` and then with `pm_resolve_linear_call`.
  The BIND entries in `big_sink.src[1:]` provide the variable values; only variables
  that appear in the resulting linear's items are kept.

  Returns a `(linear, var_vals)` pair. While a capture is active the schedule is handed
  to `capturing[0]` and an empty LINEAR UOp is returned in place of the planned one.

  Raises RuntimeError when one variable name is bound to two different values.
  """
  # big_sink srcs are all the Tensors
  call = graph_rewrite(big_sink, pm_schedule, name="schedule to linear", enter_calls=True)

  # resolving the call is recursive and is also where buffers get allocated
  linear = graph_rewrite(call, pm_resolve_linear_call, name="resolve linear call")

  # names of every variable referenced by the first src of each item in the linear
  used_vars = {v.expr for si in linear.src for v in si.src[0].variables()}

  # collect one value per used variable from the BINDs passed alongside the sink
  var_vals: dict[str, int] = {}
  for bind in big_sink.src[1:]:
    if bind.op is not Ops.BIND:
      continue
    nm = bind.src[0].expr
    if nm not in used_vars:
      continue
    val = bind.src[1].arg
    # a name seen before must carry the same value again
    if var_vals.get(nm, val) != val:
      raise RuntimeError(f"bind mismatch on {nm}, {var_vals[nm]} != {val}")
    var_vals[nm] = val

  # jit captures this schedule, no need to execute.
  if len(capturing) and CAPTURING:
    capturing[0].add_linear(linear, var_vals)
    return UOp(Ops.LINEAR, src=()), var_vals

  # BUFFER arguments of a CALL are passed to the memory planner as held buffers
  held_bufs = set()
  if call.op is Ops.CALL:
    held_bufs = {arg for arg in call.src[1:] if arg.op is Ops.BUFFER}
  return memory_plan_rewrite(linear, held_bufs), var_vals

Callers 1

linear_with_varsMethod · 0.90

Calls 6

graph_rewriteFunction · 0.90
UOpClass · 0.90
memory_plan_rewriteFunction · 0.90
variablesMethod · 0.80
getMethod · 0.45
add_linearMethod · 0.45

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…