(session_bytes, session_options, names)
| 13 | * The wrapper function for running the ONNX inference session. |
| 14 | */ |
const wrap = async (session_bytes, session_options, names) => {
  const modelBytes = new Uint8Array(session_bytes);
  const session = await createInferenceSession(modelBytes, session_options);

  /**
   * Feeds the named input tensors to the session created above and wraps the
   * requested output(s) back into `Tensor` objects.
   *
   * @param {Record<string, Tensor>} inputs Input tensors keyed by input name.
   * @returns {Promise<Tensor|Tensor[]>} A single `Tensor` when `names` is not an
   *   array; otherwise one `Tensor` per entry of `names`, in the same order.
   */
  const runWrapped = async (inputs) => {
    // isONNXProxy() is checked on every call. When it is true, clones are fed
    // instead of the caller's tensors. Presumably this keeps the caller's
    // tensors usable after the proxied run — confirm.
    const useClones = isONNXProxy();
    const feedEntries = Object.entries(inputs).map(([inputName, tensor]) => {
      const source = useClones ? tensor.clone() : tensor;
      return [inputName, source.ort_tensor];
    });
    const feed = Object.fromEntries(feedEntries);

    const results = await runInferenceSession(session, feed);

    if (!Array.isArray(names)) {
      return new Tensor(results[/** @type {string} */ (names)]);
    }
    return names.map((outputName) => new Tensor(results[outputName]));
  };

  return /** @type {any} */ (runWrapped);
};
| 33 | |
| 34 | // In-memory registry of initialized ONNX operators |
| 35 | export class TensorOpRegistry { |
no test coverage detected