| 14 | |
| 15 | |
class TorchMetaDataTracingHook(RuntimeHook):
    """Runtime hook that copies tensor shape and dtype onto an operation's variables.

    Before the forward pass, every non-None input tensor has its ``shape`` and
    ``dtype`` written to the matching entry of ``self._hook_to.inputs``; after
    the forward pass the same is done for ``self._hook_to.outputs``.

    NOTE(review): ``self._hook_to`` is presumably bound to ``operation`` by
    ``RuntimeHook.__init__`` -- confirm against the base class.
    """

    def __init__(self, operation: Operation) -> None:
        """Attach this hook to ``operation`` by delegating to the base class."""
        super().__init__(operation)

    def pre_forward_hook(self, inputs: List[torch.Tensor], **kwargs) -> list:
        """Record shape/dtype of each input tensor on the operation's input variables.

        Args:
            inputs: Input values of the hooked operation, paired positionally
                with ``self._hook_to.inputs``. Entries may be ``None``.
            **kwargs: Accepted and ignored.

        Returns:
            ``inputs`` itself, unmodified.
        """
        # A None input carries no shape or dtype, so it is only reported and
        # the corresponding variable's metadata is left untouched.
        #
        # The position comes from enumerate() rather than
        # ``self._hook_to.inputs.index(var)``: list.index returns the FIRST
        # equal element, which is the wrong slot when one variable feeds
        # several inputs of the same operation, and costs an extra scan.
        for position, (tensor, var) in enumerate(zip(inputs, self._hook_to.inputs)):
            if tensor is None:
                ppq_warning(
                    f'Unexpected input value of operation {self._hook_to.name}, '
                    f'recieving "None" at its input {position}')
            else:
                var.shape = tensor.shape
                var.dtype = tensor.dtype

        return inputs

    def post_forward_hook(self, outputs: List[torch.Tensor], **kwargs) -> list:
        """Record shape/dtype of each output tensor on the operation's output variables.

        Args:
            outputs: Output values of the hooked operation, paired positionally
                with ``self._hook_to.outputs``. ``None`` entries are skipped
                silently.
            **kwargs: Accepted and ignored.

        Returns:
            ``outputs`` itself, unmodified.
        """
        for tensor, var in zip(outputs, self._hook_to.outputs):
            if tensor is not None:
                var.shape = tensor.shape
                var.dtype = tensor.dtype

        return outputs
| 41 | |
| 42 | |
| 43 | class TorchQuantizeDelegator(Callable): |
no outgoing calls
no test coverage detected