(
self,
graph: BaseGraph,
**kwargs
)
| 181 | |
| 182 | @ empty_ppq_cache |
| 183 | def optimize( |
| 184 | self, |
| 185 | graph: BaseGraph, |
| 186 | **kwargs |
| 187 | ) -> None: |
| 188 | processor = SearchableGraph(graph) |
| 189 | |
| 190 | # fuse computing operations and its following activation. |
| 191 | if self.fuse_activation: |
| 192 | patterns = processor.pattern_matching( |
| 193 | patterns=[lambda x: x.is_computing_op, lambda x: x.type in self.activation_types], |
| 194 | edges=[[0, 1]], exclusive=True) |
| 195 | |
| 196 | for computing_op, act_op in patterns: |
| 197 | if not isinstance(act_op, QuantableOperation): continue |
| 198 | if not isinstance(computing_op, QuantableOperation): continue |
| 199 | |
| 200 | if (computing_op.platform != act_op.platform and |
| 201 | computing_op.config.output_quantization_config[0].state != QuantizationStates.FP32): |
| 202 | ppq_warning(f'Unexpected dispatching was found: ' |
| 203 | f'Op {computing_op.name} and {act_op.name} should be send to a same platform.') |
| 204 | continue |
| 205 | |
| 206 | if (len(graph.get_downstream_operations(computing_op)) == 1 and |
| 207 | len(graph.get_upstream_operations(act_op)) == 1): |
| 208 | computing_op.config.output_quantization_config[0].dominated_by = ( |
| 209 | act_op.config.output_quantization_config[0]) |
| 210 | act_op.config.input_quantization_config[0].dominated_by = ( |
| 211 | act_op.config.output_quantization_config[0]) |
| 212 | |
| 213 | if 'Swish' in self.activation_types: |
| 214 | search_engine = SearchableGraph(graph) |
| 215 | patterns = search_engine.pattern_matching( |
| 216 | patterns = [lambda x: x.is_computing_op, 'Sigmoid', 'Mul'], |
| 217 | edges = [[0, 1], [1, 2], [0, 2]], |
| 218 | exclusive = True) |
| 219 | |
| 220 | for pattern in patterns: |
| 221 | if any([not isinstance(op, QuantableOperation) for op in pattern]): |
| 222 | ppq_warning(f'There is a pattern of swish activation in your network start from {pattern[0]}, ' |
| 223 | 'however part of your swish activation is not quantable, ' |
| 224 | 'so that graph fusion can not merge their quantization configuration.') |
| 225 | continue |
| 226 | if any([op.platform != pattern[0].platform for op in pattern]): |
| 227 | ppq_warning(f'There is a pattern of swish activation in your network start from {pattern[0]}, ' |
| 228 | 'however part of your swish activation is not quantable, ' |
| 229 | 'so that graph fusion can not merge their quantization configuration.') |
| 230 | continue |
| 231 | computing, sigmoid, mul = pattern |
| 232 | |
| 233 | assert isinstance(computing, QuantableOperation) |
| 234 | assert isinstance(sigmoid, QuantableOperation) |
| 235 | assert isinstance(mul, QuantableOperation) |
| 236 | |
| 237 | master_config = mul.config.output_quantization_config[0] |
| 238 | computing.config.output_quantization_config[0].dominated_by = master_config |
| 239 | sigmoid.config.input_quantization_config[0].dominated_by = master_config |
| 240 | sigmoid.config.output_quantization_config[0].dominated_by = master_config |
nothing calls this directly
no test coverage detected