(self,
processor: BaseGraph,
dataloader: Iterable,
executor: BaseGraphExecutor,
**kwargs)
| 335 | return all([platform == platforms[0] for platform in platforms]) |
| 336 | |
| 337 | def optimize(self, |
| 338 | processor: BaseGraph, |
| 339 | dataloader: Iterable, |
| 340 | executor: BaseGraphExecutor, |
| 341 | **kwargs) -> None: |
| 342 | |
def ep_expr(operation: Operation):
    """Decide whether ``operation`` may take part in a Conv-Add-Activation merge.

    Returns True for:
      * a quantable ``Conv``;
      * a quantable op whose type is in ``PPLCUDA_ACTIVATIONS`` and that has
        exactly one upstream operation, which is a ``Conv``;
      * a quantable activation (as above) whose first upstream operation is
        already in ``merged``.
    Returns False otherwise, including for non-quantable operations.

    Reads ``graph`` and ``merged`` from the enclosing ``optimize`` scope.
    """
    if not isinstance(operation, QuantableOperation): return False
    if operation.type == 'Conv': return True
    if operation.type in PPLCUDA_ACTIVATIONS:
        upstream_ops = graph.get_upstream_operations(operation=operation)
        # Only a single-producer activation can be fused onto its Conv; the
        # length check must come before indexing [0].
        if len(upstream_ops) == 1 and upstream_ops[0].type == 'Conv': return True
        # Guard against a producer-less activation before indexing.
        if upstream_ops and upstream_ops[0] in merged: return True
    return False
| 351 | |
def retrospect(operation: QuantableOperation) -> QuantableOperation:
    """Return the quantable ``Conv`` that is the sole producer of ``operation``.

    Returns None (despite the return annotation) when ``operation`` is not a
    QuantableOperation, when it does not have exactly one upstream operation,
    or when that upstream operation is not a quantable ``Conv``.

    Reads ``graph`` from the enclosing ``optimize`` scope.
    """
    if not isinstance(operation, QuantableOperation): return None
    # Query the graph once and reuse the result.
    upstream_ops = graph.get_upstream_operations(operation)
    if len(upstream_ops) != 1: return None

    parent = upstream_ops[0]
    if parent.type != 'Conv': return None
    if not isinstance(parent, QuantableOperation): return None
    return parent
| 360 | |
def merge_fn(operation: QuantableOperation):
    """Link the output quantization configs around a quantable ``Add``.

    Step 1: if the Add has exactly one downstream operation, and that operation
    is a QuantableOperation whose type is in ``PPLCUDA_ACTIVATIONS`` on the same
    platform as the Add, the Add's first output config gets ``dominated_by``
    set to the activation's first output config.

    Step 2: walking the Add's two upstream operations in order, the first one
    that is a quantable ``Conv`` -- or an activation whose single producer is a
    quantable ``Conv`` (found via ``retrospect``) -- gets its first output
    config's ``dominated_by`` set to the Add's first output config.

    Returns early, changing nothing, when ``self.is_same_platform`` rejects the
    Add together with its upstream operations.

    Reads ``graph``, ``self`` and ``retrospect`` from the enclosing ``optimize``
    scope.

    Raises:
        AssertionError: if ``operation`` is not a quantable ``Add``, or (after
            the platform check) it does not have exactly 2 upstream operations.
    """
    assert isinstance(operation, QuantableOperation) and operation.type == 'Add'
    # check if upstream ops can be merged
    up_ops = graph.get_upstream_operations(operation)
    if not self.is_same_platform(up_ops + [operation]): return
    # Validate the input count before mutating any config so a malformed Add
    # cannot leave the graph half-modified.
    assert len(up_ops) == 2, f'Opeartion {operation.name} should has exact 2 input operations.'

    # Conv - Add - Relu Merge
    config = operation.config.output_quantization_config[0]

    # Step - 1: merge add output to next activation.
    down_ops = graph.get_downstream_operations(operation)
    if (len(down_ops) == 1 and
        down_ops[0].type in PPLCUDA_ACTIVATIONS and
        isinstance(down_ops[0], QuantableOperation) and
        down_ops[0].platform == operation.platform):
        config.dominated_by = down_ops[0].config.output_quantization_config[0]

    # Step - 2: disable input conv's quantization(only one).
    target_operation = None
    for op in up_ops:
        # Only quantable ops carry a ``config``; mirror retrospect's guard.
        if op.type == 'Conv' and isinstance(op, QuantableOperation):
            target_operation = op
        elif op.type in PPLCUDA_ACTIVATIONS:
            # Walk back from the activation (not from the Add, which has two
            # producers) to the Conv that feeds it.
            target_operation = retrospect(op)
        if target_operation is not None:
            break

    if target_operation is not None:
        target_operation.config.output_quantization_config[0].dominated_by = config
| 393 | |
| 394 | graph, merged, unchanged = processor.graph, set(), False |
nothing calls this directly
no test coverage detected