(
self,
inputs: Union[dict, list, torch.Tensor],
executing_order: List[Operation],
output_names:List[str] = None,
hooks: Dict[str, RuntimeHook] = None,
)
| 455 | ) |
| 456 | |
| 457 | def __forward( |
| 458 | self, |
| 459 | inputs: Union[dict, list, torch.Tensor], |
| 460 | executing_order: List[Operation], |
| 461 | output_names:List[str] = None, |
| 462 | hooks: Dict[str, RuntimeHook] = None, |
| 463 | ) -> List[torch.Tensor]: |
| 464 | # processing with different input format |
| 465 | if isinstance(inputs, dict): |
| 466 | # directly feed value into variables |
| 467 | for name, value in inputs.items(): |
| 468 | if name in self._graph.variables: |
| 469 | var = self._graph.variables[name] |
| 470 | var.value = value |
| 471 | else: |
| 472 | print(f'Can not find variable {name} in your graph, please check.') |
| 473 | else: |
| 474 | inputs = self.prepare_input(inputs=inputs) |
| 475 | for key, value in inputs.items(): |
| 476 | assert isinstance(value, torch.Tensor), \ |
| 477 | f'TorchExecutor can only accept tensor as its input, while {type(value)} was given' |
| 478 | # input is acceptable, feed input value |
| 479 | self._graph_input_dictionary[key].value = value |
| 480 | |
| 481 | # processing with output |
| 482 | last_idx = 0 # record last variable |
| 483 | if output_names is None: |
| 484 | output_names = [name for name in self._graph.outputs] |
| 485 | for name in output_names: |
| 486 | if name not in self._graph.variables: |
| 487 | raise KeyError(f'You are requiring output value of variable {name}(is not a variable name), ' |
| 488 | 'however it is not a valid variable of current graph.') |
| 489 | source_op = self._graph.variables[name].source_op |
| 490 | if source_op is not None: |
| 491 | last_idx = max(last_idx, executing_order.index(source_op) + 1) |
| 492 | |
| 493 | visited_op, result_collector = [], [None for _ in output_names] |
| 494 | # output name can be the same as input name, collect them directly. |
| 495 | for name in output_names: |
| 496 | if name in inputs: |
| 497 | result_collector[output_names.index(name)] = inputs[name] |
| 498 | |
| 499 | for operation in executing_order[: last_idx]: |
| 500 | try: |
| 501 | assert isinstance(operation, Operation), 'Oops, seems you got something weird in your graph' |
| 502 | assert isinstance(operation.platform, TargetPlatform), ( |
| 503 | f'Operation {operation.name} has an invalid platform setting, ' |
| 504 | f'only PPQ.core.TargetPlatform is expected here, while {type(operation.platform)} was given') |
| 505 | platform_dispatching_table = OPERATION_FORWARD_TABLE[operation.platform] |
| 506 | if operation.type not in platform_dispatching_table: |
| 507 | raise NotImplementedError( |
| 508 | f'Graph op: {operation.name}({operation.type}) ' |
| 509 | f'has no backend implementation on target platform {operation.platform}. ' |
| 510 | 'Register this op to ppq.executor.base.py and ppq.executor.op first') |
| 511 | operation_forward_func = platform_dispatching_table[operation.type] |
| 512 | operation_runtime_hook = hooks[operation.name] if (hooks is not None) and (operation.name in hooks) else None |
| 513 | inputs = [var.value for var in operation.inputs] |
| 514 |
no test coverage detected