(*x)
| 48 | assert hasattr(model, "forward") or callable(model), "model needs a forward function" |
| 49 | @TinyJit  # wrapped by TinyJit so the repeated calls made after this definition can be captured |
| 50 | def run(*x):  # forward the positional inputs to the model and return its outputs as a list |
| 51 | out = model.forward(*x) if hasattr(model, "forward") else model(*x)  # use model.forward when present, otherwise call the model object itself |
| 52 | assert isinstance(out, (tuple, list, Tensor)), "model output must be a Tensor, tuple, or a list of Tensors for export"  # only the outer type is checked here; elements of a tuple or list are not inspected |
| 53 | out = [out] if isinstance(out, Tensor) else out  # wrap a lone Tensor so the return statement always iterates a sequence |
| 54 | return [o.realize() for o in out]  # call realize() on each output; the result is always a list, even when the model returned a tuple |
| 55 | |
| 56 | # run twice to trigger JIT capture |
| 57 | for _ in range(2): the_output = run(*args) |