MCPcopy
hub / github.com/tinygrad/tinygrad / compile_net

Function compile_net

extra/export_model.py:20–45  ·  view source on GitHub ↗
(linear:UOp, output_bufs:List[Buffer])

Source from the content-addressed store, hash-verified

18 return (u for u in linear.toposort(gate=lambda x: x.op not in _KERNEL_ASTS) if u.op is Ops.CALL and u.src[0].op in _KERNEL_ASTS)
19
def compile_net(linear:UOp, output_bufs:List[Buffer]) -> Tuple[Dict[str,str], List, Dict[str,Tuple[int,DType,int]], Dict[str,Buffer]]:
  """Walk the kernel calls of `linear` and collect what is needed to export them.

  Returns a 4-tuple:
    functions    -- maps each kernel's function name to `prg.src[3].arg`
                    (presumably the rendered kernel source -- confirm against `to_program`).
    statements   -- one `(function_name, call_args, global_size, local_size)` tuple per
                    kernel call, in the order `iter_kernel_calls` yields them.
    buffers      -- maps each assigned buffer name to `(size_in_bytes, dtype, key)`.
    bufs_to_save -- maps `buf_N` names to their Buffer, for buffers first seen as a
                    non-first argument of a call.
  """
  # NOTE(review): the return annotation declares the third element of each buffer entry
  # as `int`, but the value stored below is the tuple `key` -- confirm and fix upstream.
  output_labels = {id(out_buf): f"output{idx}" for idx, out_buf in enumerate(output_bufs)}
  functions = {}
  bufs = {}           # key -> (name, size_in_bytes, dtype, key)
  bufs_to_save = {}
  statements = []
  n = 0               # counter used to name anonymous buffers buf_0, buf_1, ...

  def name_of(bu:UOp, is_out:bool) -> str:
    """Return the exported name for `bu`, registering it in `bufs` on first sight."""
    nonlocal n
    if bu.op is Ops.PARAM:
      # Params are named after their argument; they are re-registered on every call
      # with identical values, so no cache lookup is needed here.
      key = ("in", bu.arg)
      name = f"input{bu.arg}"
      size = prod(bu.shape)*bu.dtype.itemsize
    else:
      b = bu.buffer
      # Identify a buffer by its base object plus the view (offset/size/dtype) onto it.
      key = (id(b.base), b.offset, b.size, b.dtype)
      size = b.size*b.dtype.itemsize
      if key in bufs:
        return bufs[key][0]
      name = output_labels.get(id(b))
      if name is None:
        name = f"buf_{n}"
        n += 1
        # Only buffers that are not written first (is_out False) are kept for saving.
        if not is_out:
          bufs_to_save[name] = b
    bufs[key] = (name, size, bu.dtype, key)
    return name

  for call in iter_kernel_calls(linear):
    # BIND sources are not buffer arguments; drop them before naming.
    arg_uops = [src for src in call.src[1:] if src.op is not Ops.BIND]
    renderer = Device[arg_uops[0].device].renderer
    prg = to_program(call.src[0], renderer)
    info = prg.arg
    fn_name = info.function_name
    functions[fn_name] = prg.src[3].arg
    # The first remaining argument of each call is treated as the kernel's output.
    buf_names = [name_of(bu, pos == 0) for pos, bu in enumerate(arg_uops)]
    var_args = [v for v in info.vars if v.op is Ops.DEFINE_VAR]
    statements.append((fn_name, buf_names + var_args, info.global_size, info.local_size))

  buffers = {name: (size, dtype, key) for name, size, dtype, key in bufs.values()}
  return functions, statements, buffers, bufs_to_save
46
47def jit_model(model, *args) -> Tuple[UOp, List[Buffer]]:
48 assert hasattr(model, "forward") or callable(model), "model needs a forward function"

Callers 3

compile_onnx_model · Function · 0.90
compile_step · Function · 0.90
export_model · Function · 0.85

Calls 5

to_program · Function · 0.90
id · Function · 0.85
iter_kernel_calls · Function · 0.85
name_of · Function · 0.85
append · Method · 0.80

Tested by

no test coverage detected

Used in the wild real call sites across dependent graphs

searching dependent graphs…