(
self,
inputs: Union[torch.Tensor, list, dict],
calib_dataloader: Iterable,
executor: BaseGraphExecutor,
setting: QuantizationSetting,
**kwargs
)
| 29 | |
@empty_ppq_cache
def quantize(
    self,
    inputs: Union[torch.Tensor, list, dict],
    calib_dataloader: Iterable,
    executor: BaseGraphExecutor,
    setting: QuantizationSetting,
    **kwargs
) -> None:
    """Run the full quantization flow on ``self._graph``.

    The flow has three stages, executed strictly in this order:

    1. Build the pre-quantization pipeline from ``setting`` and run its
       ``optimize`` over the graph.
    2. Reload the graph into ``executor``, trace operation meta with
       ``inputs``, resolve every operation's platform and call
       ``quantize_operation`` on each operation that is dispatched to a
       platform other than FP32 / SOI.
    3. Reload the graph again, build the quantization pipeline from
       ``setting`` and run its ``optimize`` over the graph.

    When ``self._verbose`` is truthy, the text returned by
    ``self.report()`` and a completion message are printed at the end.

    Args:
        inputs: Passed to ``executor.tracing_operation_meta``.
        calib_dataloader: Forwarded as ``dataloader`` to both pipelines.
        executor: Executor that loads the graph and is forwarded to both
            pipelines.
        setting: Setting object both pipelines are built from.
        **kwargs: Forwarded unchanged to both ``optimize`` calls.
    """
    # Stage 1 -- pre-quantization pipeline.
    # Per the original author's note, this pipeline changes the network
    # structure and its float values.
    pre_pipeline = self.build_prequant_pipeline(setting, executor=executor)
    pre_pipeline.optimize(
        graph=self._graph,
        dataloader=calib_dataloader,
        executor=executor,
        verbose=self._verbose,
        **kwargs
    )

    # Stage 2 -- quantize every operation that is not kept in FP32 / SOI.
    executor.load_graph(self._graph)
    executor.tracing_operation_meta(inputs=inputs)

    skipped_platforms = {TargetPlatform.FP32, TargetPlatform.SOI}
    for name, op in self._graph.operations.items():
        # Operations without an explicit dispatch get one here: quantable
        # types go to this quantizer's target platform, the rest to FP32.
        if op.platform == TargetPlatform.UNSPECIFIED:
            op.platform = (
                self.target_platform
                if op.type in self.quant_operation_types
                else TargetPlatform.FP32
            )

        if op.platform in skipped_platforms:
            continue
        self.quantize_operation(name)

    # Stage 3 -- graph optimization pipeline.
    # Quantizing operations modifies the network structure, so the graph
    # is loaded into the executor again before any further execution.
    executor.load_graph(self._graph)
    post_pipeline = self.build_quant_pipeline(setting)
    post_pipeline.optimize(
        graph=self._graph,
        dataloader=calib_dataloader,
        executor=executor,
        verbose=self._verbose,
        **kwargs
    )

    if self._verbose:
        print(self.report(), end='')
        print('Network Quantization Finished.')
| 79 | |
| 80 | def quantize_operation(self, op_name: str, platform: TargetPlatform=None) -> QuantableOperation: |
| 81 | if op_name not in self._graph.operations: |
no test coverage detected