MCPcopy
hub / github.com/OpenPPL/ppq / __forward

Method __forward

ppq/executor/torch.py:457–577  ·  view source on GitHub ↗
(
        self,
        inputs: Union[dict, list, torch.Tensor],
        executing_order: List[Operation],
        output_names:List[str] = None,
        hooks: Dict[str, RuntimeHook] = None,
    )

Source from the content-addressed store, hash-verified

455 )
456
457 def __forward(
458 self,
459 inputs: Union[dict, list, torch.Tensor],
460 executing_order: List[Operation],
461 output_names:List[str] = None,
462 hooks: Dict[str, RuntimeHook] = None,
463 ) -> List[torch.Tensor]:
464 # processing with different input format
465 if isinstance(inputs, dict):
466 # directly feed value into variables
467 for name, value in inputs.items():
468 if name in self._graph.variables:
469 var = self._graph.variables[name]
470 var.value = value
471 else:
472 print(f'Can not find variable {name} in your graph, please check.')
473 else:
474 inputs = self.prepare_input(inputs=inputs)
475 for key, value in inputs.items():
476 assert isinstance(value, torch.Tensor), \
477 f'TorchExecutor can only accept tensor as its input, while {type(value)} was given'
478 # input is acceptable, feed input value
479 self._graph_input_dictionary[key].value = value
480
481 # processing with output
482 last_idx = 0 # record last variable
483 if output_names is None:
484 output_names = [name for name in self._graph.outputs]
485 for name in output_names:
486 if name not in self._graph.variables:
487 raise KeyError(f'You are requiring output value of variable {name}(is not a variable name), '
488 'however it is not a valid variable of current graph.')
489 source_op = self._graph.variables[name].source_op
490 if source_op is not None:
491 last_idx = max(last_idx, executing_order.index(source_op) + 1)
492
493 visited_op, result_collector = [], [None for _ in output_names]
494 # output name can be the same as input name, collect them directly.
495 for name in output_names:
496 if name in inputs:
497 result_collector[output_names.index(name)] = inputs[name]
498
499 for operation in executing_order[: last_idx]:
500 try:
501 assert isinstance(operation, Operation), 'Oops, seems you got something weird in your graph'
502 assert isinstance(operation.platform, TargetPlatform), (
503 f'Operation {operation.name} has an invalid platform setting, '
504 f'only PPQ.core.TargetPlatform is expected here, while {type(operation.platform)} was given')
505 platform_dispatching_table = OPERATION_FORWARD_TABLE[operation.platform]
506 if operation.type not in platform_dispatching_table:
507 raise NotImplementedError(
508 f'Graph op: {operation.name}({operation.type}) '
509 f'has no backend implementation on target platform {operation.platform}. '
510 'Register this op to ppq.executor.base.py and ppq.executor.op first')
511 operation_forward_func = platform_dispatching_table[operation.type]
512 operation_runtime_hook = hooks[operation.name] if (hooks is not None) and (operation.name in hooks) else None
513 inputs = [var.value for var in operation.inputs]
514

Callers 4

forward (method) · 0.95
forward_with_gradient (method) · 0.95
partial_graph_forward (method) · 0.95

Calls 5

quantize_function (method) · 0.95
prepare_input (method) · 0.80
pre_forward_hook (method) · 0.45
post_forward_hook (method) · 0.45
append (method) · 0.45

Tested by

no test coverage detected