GraphMerger implements all graph-fusion-related functions.
| 499 | |
| 500 | |
| 501 | class GraphMerger(GraphCommandProcessor): |
| 502 | """Graph Merger implements all graph fusion related functions.""" |
def _acceptable_command_types(self) -> List[GraphCommandType]:
    """Return the command types this merger handles: FUSE_BN and FUSE_BIAS_ADD.

    A new list is built on every call, so callers may mutate the result freely.
    """
    supported = [GraphCommandType.FUSE_BN, GraphCommandType.FUSE_BIAS_ADD]
    # Further fusion command types can be appended here as they are implemented.
    return supported
| 509 | |
def process(self, command: GraphCommand) -> Any:
    """Dispatch a fusion command to the matching fuse_* method.

    FUSE_BN runs ``self.fuse_bn()`` and FUSE_BIAS_ADD runs ``self.fuse_bias_add()``.
    The selected routine's return value is passed through unchanged. Any other
    command type falls through and yields None.
    """
    requested = command.command_type
    if requested == GraphCommandType.FUSE_BN:
        return self.fuse_bn()
    elif requested == GraphCommandType.FUSE_BIAS_ADD:
        return self.fuse_bias_add()
    # Unhandled command types: nothing to do.
    return None
| 515 | |
| 516 | |
| 517 | def fuse_bn(self): |
| 518 | search_engine = SearchableGraph(graph=self.graph) |
| 519 | paths = search_engine.path_matching( |
| 520 | sp_expr=lambda x: x.type in {'Conv', 'Gemm', 'ConvTranspose'}, |
| 521 | rp_expr=lambda x, y: False, |
| 522 | ep_expr=lambda x: x.type == 'BatchNormalization', |
| 523 | direction='down') |
| 524 | |
| 525 | for path in paths: |
| 526 | path = path.tolist() |
| 527 | assert len(path) == 2, ('Oops seems we got something unexpected.') |
| 528 | |
| 529 | computing_op, bn_op = path |
| 530 | assert isinstance(computing_op, Operation) and isinstance(bn_op, Operation) |
| 531 | |
| 532 | if (len(self.graph.get_downstream_operations(computing_op)) != 1 or |
| 533 | len(self.graph.get_upstream_operations(bn_op)) != 1): |
| 534 | ppq_warning(f'PPQ can not merge operation {computing_op.name} and {bn_op.name}, ' |
| 535 | 'this is not suppose to happen with your network, ' |
| 536 | 'network with batchnorm inside might not be able to quantize and deploy.') |
| 537 | continue |
| 538 | |
| 539 | assert len(bn_op.parameters) == 4, 'BatchNorm should have 4 parameters, namely alpha, beta, mean, var' |
| 540 | alpha = bn_op.parameters[0].value |
| 541 | beta = bn_op.parameters[1].value |
| 542 | mean = bn_op.parameters[2].value |
| 543 | var = bn_op.parameters[3].value |
| 544 | epsilon = bn_op.attributes.get('epsilon', 1e-5) |
| 545 | |
| 546 | if computing_op.num_of_parameter == 1: |
| 547 | w = computing_op.parameters[0].value # no bias. |
| 548 | assert isinstance(w, torch.Tensor), 'values of parameters are assumed as torch Tensor' |
| 549 | if computing_op.type == 'ConvTranspose': |
| 550 | b = torch.zeros(w.shape[1] * computing_op.attributes.get('group', 1)) |
| 551 | elif computing_op.type == 'Gemm' and computing_op.attributes.get('transB', 0) == 0: |
| 552 | b = torch.zeros(w.shape[1]) |
| 553 | else: |
| 554 | b = torch.zeros(w.shape[0]) |
| 555 | else: |
| 556 | w, b = [var.value for var in computing_op.parameters[: 2]] # has bias. |
| 557 | |
| 558 | if computing_op.type == 'Conv': |
no outgoing calls
no test coverage detected