MCPcopy
hub / github.com/OpenPPL/ppq / TorchMetaDataTracingHook

Class TorchMetaDataTracingHook

ppq/executor/torch.py:16–40  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

14
15
class TorchMetaDataTracingHook(RuntimeHook):
    """Runtime hook that records tensor metadata onto an operation's variables.

    During execution it copies the ``shape`` and ``dtype`` of every runtime
    input/output tensor onto the positionally matching input/output variable
    of the hooked operation. The tensors themselves are passed through
    unchanged.
    """

    def __init__(self, operation: Operation) -> None:
        # NOTE(review): RuntimeHook.__init__ presumably stores ``operation``
        # as ``self._hook_to``, which both hooks below read - confirm.
        super().__init__(operation)

    def pre_forward_hook(self, inputs: List[torch.Tensor], **kwargs) -> list:
        """Record shape/dtype of each input tensor on the matching input variable.

        Some operations receive ``None`` as an input value; there is no tensor
        to read metadata from in that case, so a warning naming the input
        position is emitted and that variable is left untouched.

        Args:
            inputs: runtime input values, positionally aligned with
                ``self._hook_to.inputs``. Pairing uses ``zip``, so surplus
                entries on either side are ignored.
            **kwargs: accepted for hook-interface compatibility; unused.

        Returns:
            ``inputs``, unchanged.
        """
        for position, (tensor, var) in enumerate(zip(inputs, self._hook_to.inputs)):
            if tensor is None:
                # Use the loop position rather than ``inputs.index(var)``:
                # list.index returns the FIRST slot holding ``var``, which is
                # wrong when one variable feeds several inputs (e.g. Add(x, x)).
                ppq_warning(
                    f'Unexpected input value of operation {self._hook_to.name}, '
                    f'recieving "None" at its input {position}')
            else:
                var.shape = tensor.shape
                var.dtype = tensor.dtype

        return inputs

    def post_forward_hook(self, outputs: List[torch.Tensor], **kwargs) -> list:
        """Record shape/dtype of each output tensor on the matching output variable.

        ``None`` outputs are skipped silently (no warning, variable untouched).

        Args:
            outputs: runtime output values, positionally aligned with
                ``self._hook_to.outputs``. Pairing uses ``zip``, so surplus
                entries on either side are ignored.
            **kwargs: accepted for hook-interface compatibility; unused.

        Returns:
            ``outputs``, unchanged.
        """
        for tensor, var in zip(outputs, self._hook_to.outputs):
            if tensor is not None:
                var.shape = tensor.shape
                var.dtype = tensor.dtype

        return outputs
41
42
43class TorchQuantizeDelegator(Callable):

Callers 1

Calls

no outgoing calls

Tested by

no test coverage detected