| 257 | |
| 258 | |
| 259 | class QuantableGraph(GraphCommandProcessor): |
def process(self, command: GraphCommand) -> Any:
    """Dispatch a graph command to the handler for its type.

    Only ``GraphCommandType.QUANTIZE_OPERATION`` is handled: the command's
    op name, target platform and config are forwarded to
    ``quantize_operation`` and whatever that returns is returned here.
    Every other command type is ignored and ``None`` is returned.
    """
    if command.command_type != GraphCommandType.QUANTIZE_OPERATION:
        return None
    # Narrow the generic command to the concrete type carrying op_name/config.
    assert isinstance(command, QuantizeOperationCommand)
    return self.quantize_operation(
        command.op_name, command.target_platform, command.config)
| 265 | |
def _acceptable_command_types(self) -> List[GraphCommandType]:
    """Return the command types this processor accepts (a fresh list each call).

    Only quantize-operation commands are accepted.
    """
    accepted = [GraphCommandType.QUANTIZE_OPERATION]
    return accepted
| 270 | |
def quantize_operation(
    self,
    operation_name: str,
    target_platform: TargetPlatform,
    quantization_config: OperationQuantizationConfig
) -> QuantableOperation:
    """Replace the named operation with a QuantableOperation built from it.

    A QuantableOperation is constructed from the existing operation using the
    given config and platform, and a ReplaceOperationCommand for it is sent to
    the next processor in the chain. Every input/output variable of the new
    operation that is not already a QuantableVariable is then replaced, via a
    ReplaceVariableCommand, with a QuantableVariable converted from it.
    Finally ``store_parameter_value()`` is called on the new operation.

    Args:
        operation_name: name of the operation to quantize; must be a key of
            ``self.graph.operations``.
        target_platform: platform passed to the new QuantableOperation.
        quantization_config: quantization config for the new operation.

    Returns:
        The newly created QuantableOperation.

    Raises:
        KeyError: if ``operation_name`` is not in the graph.
        RuntimeError: if there is no next command processor in the chain to
            perform the replacements.
    """
    if operation_name not in self.graph.operations:
        raise KeyError(f'Operation {operation_name} is not in your graph, Please check your input.')

    # Replacement is delegated to the next processor in the chain; check that
    # it exists before building anything from the live operation.
    if self._next_command_processor is None:
        raise RuntimeError(
            'To replace a operation, your processor chain must have a GraphReplacer Processor.')

    operation = self._graph.operations[operation_name]
    quantized_operation = QuantableOperation(
        convert_from=operation,
        quantize_config=quantization_config,
        platform=target_platform,
    )

    # calling other chain responder to replace operation with quantized one.
    self._next_command_processor(ReplaceOperationCommand(operation_name, quantized_operation))

    # Replace all related variables with quantable ones. `inputs + outputs`
    # builds a new snapshot list, so replacements performed by the chain do
    # not change what is iterated. Because the snapshot keeps the original
    # objects, a variable listed twice (e.g. x in Add(x, x)) would otherwise
    # be converted and replaced a second time from its stale object; convert
    # each variable name only once.
    converted_names = set()
    for var in quantized_operation.inputs + quantized_operation.outputs:
        if isinstance(var, QuantableVariable) or var.name in converted_names:
            continue
        converted_names.add(var.name)
        self._next_command_processor(
            ReplaceVariableCommand(
                var_name=var.name,
                replace_to=QuantableVariable(convert_from=var)
            )
        )
    quantized_operation.store_parameter_value()
    # The signature promises the new operation; without this the method (and
    # `process`, which returns its result) would always yield None.
    return quantized_operation
| 303 | |
def dequantize_operation(
    self,
    operation_name: str
) -> Operation:
    """Return the named operation in its dequantized form.

    If the operation is a QuantableOperation, the result of its
    ``dequantize()`` call is returned; any other operation is returned
    unchanged.

    Raises:
        KeyError: if ``operation_name`` is not in the graph.
    """
    if operation_name not in self.graph.operations:
        raise KeyError(f'Operation {operation_name} is not in your graph, Please check your input.')
    target = self._graph.operations[operation_name]
    return target.dequantize() if isinstance(target, QuantableOperation) else target
| 313 | |
| 314 | def dequantize_graph(self, expire_device: str = 'cpu'): |
| 315 | """一个方便懒人的函数.""" |
| 316 | for operation in self.graph.operations.values(): |
no outgoing calls
no test coverage detected