MCP · copy · Index your code
hub / github.com/tinygrad/tinygrad / jit_model

Function jit_model

extra/export_model.py:47–59  ·  view source on GitHub ↗
(model, *args)

Source from the content-addressed store, hash-verified

45 return functions, statements, {name:(size, dtype, key) for name, size, dtype, key in bufs.values()}, bufs_to_save
46
def jit_model(model, *args) -> Tuple[UOp, List[Buffer]]:
  """Trace `model` through TinyJit and return what the JIT captured.

  `model` must either expose a `forward` method (which is preferred when
  present) or be callable itself; it is invoked with `*args`. Its result must
  be a Tensor, or a tuple/list of Tensors; every output tensor is realized.

  Returns:
    A pair of (`run.captured.linear` from the TinyJit capture, the realized
    buffers backing each output tensor, in output order).
  """
  assert hasattr(model, "forward") or callable(model), "model needs a forward function"

  @TinyJit
  def run(*inputs):
    # `forward` is looked up on every invocation, not bound once up front.
    if hasattr(model, "forward"):
      result = model.forward(*inputs)
    else:
      result = model(*inputs)
    assert isinstance(result, (tuple, list, Tensor)), "model output must be a Tensor, tuple, or a list of Tensors for export"
    # Wrap a lone Tensor so the realize step below always iterates a sequence.
    outputs = (result,) if isinstance(result, Tensor) else result
    return [t.realize() for t in outputs]

  # Two invocations are needed to trigger JIT capture; only the second call's
  # outputs are kept.
  run(*args)
  final_outputs = run(*args)

  captured = run.captured
  assert captured is not None
  linear = captured.linear
  # NOTE(review): `.realized` is assumed to be populated after realize(); it is not checked here.
  buffers = [t.uop.base.realized for t in final_outputs]
  return linear, buffers
60
61def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]],
62 bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str], weight_names={}, model_name="model", symbolic_vars={}, wasm=False) -> str:

Callers 3

compile_onnx_model · Function · 0.90
compile_step · Function · 0.90
export_model · Function · 0.85

Calls 1

run · Function · 0.70

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…