MCPcopy
hub / github.com/OpenPPL/ppq / GraphMerger

Class GraphMerger

ppq/IR/morph.py:501–1075  ·  view source on GitHub ↗

GraphMerger implements all graph-fusion-related functions.

Sourced from the content-addressed store (hash-verified).

499
500
501class GraphMerger(GraphCommandProcessor):
502 """Graph Merger implements all graph fusion related functions."""
503 def _acceptable_command_types(self) -> List[GraphCommandType]:
504 return [
505 # add more extensions in the future
506 GraphCommandType.FUSE_BN,
507 GraphCommandType.FUSE_BIAS_ADD
508 ]
509
510 def process(self, command: GraphCommand) -> Any:
511 if command.command_type == GraphCommandType.FUSE_BN:
512 return self.fuse_bn()
513 if command.command_type == GraphCommandType.FUSE_BIAS_ADD:
514 return self.fuse_bias_add()
515
516
517 def fuse_bn(self):
518 search_engine = SearchableGraph(graph=self.graph)
519 paths = search_engine.path_matching(
520 sp_expr=lambda x: x.type in {'Conv', 'Gemm', 'ConvTranspose'},
521 rp_expr=lambda x, y: False,
522 ep_expr=lambda x: x.type == 'BatchNormalization',
523 direction='down')
524
525 for path in paths:
526 path = path.tolist()
527 assert len(path) == 2, ('Oops seems we got something unexpected.')
528
529 computing_op, bn_op = path
530 assert isinstance(computing_op, Operation) and isinstance(bn_op, Operation)
531
532 if (len(self.graph.get_downstream_operations(computing_op)) != 1 or
533 len(self.graph.get_upstream_operations(bn_op)) != 1):
534 ppq_warning(f'PPQ can not merge operation {computing_op.name} and {bn_op.name}, '
535 'this is not suppose to happen with your network, '
536 'network with batchnorm inside might not be able to quantize and deploy.')
537 continue
538
539 assert len(bn_op.parameters) == 4, 'BatchNorm should have 4 parameters, namely alpha, beta, mean, var'
540 alpha = bn_op.parameters[0].value
541 beta = bn_op.parameters[1].value
542 mean = bn_op.parameters[2].value
543 var = bn_op.parameters[3].value
544 epsilon = bn_op.attributes.get('epsilon', 1e-5)
545
546 if computing_op.num_of_parameter == 1:
547 w = computing_op.parameters[0].value # no bias.
548 assert isinstance(w, torch.Tensor), 'values of parameters are assumed as torch Tensor'
549 if computing_op.type == 'ConvTranspose':
550 b = torch.zeros(w.shape[1] * computing_op.attributes.get('group', 1))
551 elif computing_op.type == 'Gemm' and computing_op.attributes.get('transB', 0) == 0:
552 b = torch.zeros(w.shape[1])
553 else:
554 b = torch.zeros(w.shape[0])
555 else:
556 w, b = [var.value for var in computing_op.parameters[: 2]] # has bias.
557
558 if computing_op.type == 'Conv':

Callers 5

infer_shape · Method · 0.90
format_graph · Function · 0.90
testFuseBias.py · File · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected