(model, *args)
| 45 | return functions, statements, {name:(size, dtype, key) for name, size, dtype, key in bufs.values()}, bufs_to_save |
| 46 | |
def jit_model(model, *args) -> Tuple[UOp, List[Buffer]]:
  """Run *model* under TinyJit and return what the JIT captured.

  The model is called as ``model.forward(*args)`` when it has a ``forward``
  attribute, otherwise as ``model(*args)``. Its output must be a Tensor or a
  tuple/list of Tensors; every output is realized inside the jitted function.

  Returns a pair: ``run.captured.linear`` from the TinyJit wrapper (annotated
  as ``UOp``) and, for each output tensor of the final call, the object found
  at ``tensor.uop.base.realized`` (annotated as ``Buffer``).
  """
  assert hasattr(model, "forward") or callable(model), "model needs a forward function"

  @TinyJit
  def run(*inputs):
    # The forward-vs-call choice is re-made on every invocation.
    if hasattr(model, "forward"):
      result = model.forward(*inputs)
    else:
      result = model(*inputs)
    assert isinstance(result, (tuple, list, Tensor)), "model output must be a Tensor, tuple, or a list of Tensors for export"
    # Normalise a single Tensor into a one-element sequence before realizing.
    tensors = [result] if isinstance(result, Tensor) else result
    return [t.realize() for t in tensors]

  # Calling the jitted function twice triggers JIT capture; keep the outputs
  # of the last call.
  for _call in range(2):
    final_outputs = run(*args)
  assert run.captured is not None
  linear_program = run.captured.linear
  output_buffers = [t.uop.base.realized for t in final_outputs]
  return linear_program, output_buffers
| 60 | |
| 61 | def export_model_clang(functions:Dict[str,str], statements:Dict[str,Tuple[str,int,int]], bufs:Dict[str,Tuple[str,int,int]], |
| 62 | bufs_to_save:Dict[str,Tensor], input_names:List[str], output_names:List[str], weight_names={}, model_name="model", symbolic_vars={}, wasm=False) -> str: |
no test coverage detected
searching dependent graphs…