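Executes the onnx model. Args: `output_names`: requested outputs by name, or None for all outputs; `feed_inputs`: dictionary `{ input name: input value }`; `attributes`: attribute values, needed when the instance runs a FunctionProto; `intermediate`: if True, the function returns all the results, final and intermediate, in a single dictionary; if False, only the final results are returned, as a list.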
(
self,
output_names,
feed_inputs: dict[str, Any],
attributes: dict[str, Any] | None = None,
intermediate: bool = False,
)
| 548 | ) |
| 549 | |
| 550 | def run( |
| 551 | self, |
| 552 | output_names, |
| 553 | feed_inputs: dict[str, Any], |
| 554 | attributes: dict[str, Any] | None = None, |
| 555 | intermediate: bool = False, |
| 556 | ) -> dict[str, Any] | list[Any]: |
| 557 | """Executes the onnx model. |
| 558 | |
| 559 | Args: |
| 560 | output_names: requested outputs by names, None for all |
| 561 | feed_inputs: dictionary `{ input name: input value }` |
| 562 | attributes: attributes value if the instance runs a |
| 563 | FunctionProto |
| 564 | intermediate: if True, the function returns all the results, |
| 565 | final ones and intermediates one in a same dictionary, |
| 566 | if False, only the final results are returned in a list |
| 567 | |
| 568 | Returns: |
| 569 | list of requested outputs if intermediate is False, |
| 570 | named results in a dictionary otherwise |
| 571 | """ |
| 572 | if output_names is None: |
| 573 | output_names = self.output_names |
| 574 | if isinstance(self.proto_, FunctionProto) and attributes is None: |
| 575 | raise TypeError |
| 576 | |
| 577 | # step 1: inputs and initializers |
| 578 | results = {"": None} # optional input |
| 579 | results.update(self.rt_inits_) # type: ignore[arg-type] |
| 580 | results.update(feed_inputs) |
| 581 | for k, v in self.rt_inits_.items(): |
| 582 | self._log(2, " +C %s: %s", k, v) # type: ignore[arg-type] |
| 583 | for k, v in feed_inputs.items(): |
| 584 | self._log(2, " +I %s: %s", k, v) # type: ignore[arg-type] |
| 585 | |
| 586 | # step 2: execute nodes |
| 587 | for node in self.rt_nodes_: |
| 588 | self._log(1, "%s(%s) -> %s", node.op_type, node.input, node.output) |
| 589 | for i in node.input: |
| 590 | if i not in results: |
| 591 | raise RuntimeError( |
| 592 | f"Unable to find input {i!r} in known results {sorted(results)}, " |
| 593 | f"self.rt_inits_ has {sorted(self.rt_inits_)}, " |
| 594 | f"feed_inputs has {sorted(feed_inputs)}." |
| 595 | ) |
| 596 | inputs = [results[i] for i in node.input] |
| 597 | linked_attributes = {} |
| 598 | if node.has_linked_attribute and attributes: |
| 599 | linked_attributes["linked_attributes"] = attributes |
| 600 | if node.need_context(): |
| 601 | outputs = node.run(*inputs, context=results, **linked_attributes) |
| 602 | else: |
| 603 | outputs = node.run(*inputs, **linked_attributes) |
| 604 | for name, value in zip(node.output, outputs, strict=False): |
| 605 | self._log(2, " + %s: %s", name, value) # type: ignore[arg-type] |
| 606 | results[name] = value |
| 607 |