MCPcopy
hub / github.com/OpenPPL/ppq / fuse_bn

Method fuse_bn

ppq/IR/morph.py:517–615  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

515
516
517 def fuse_bn(self):
518 search_engine = SearchableGraph(graph=self.graph)
519 paths = search_engine.path_matching(
520 sp_expr=lambda x: x.type in {'Conv', 'Gemm', 'ConvTranspose'},
521 rp_expr=lambda x, y: False,
522 ep_expr=lambda x: x.type == 'BatchNormalization',
523 direction='down')
524
525 for path in paths:
526 path = path.tolist()
527 assert len(path) == 2, ('Oops seems we got something unexpected.')
528
529 computing_op, bn_op = path
530 assert isinstance(computing_op, Operation) and isinstance(bn_op, Operation)
531
532 if (len(self.graph.get_downstream_operations(computing_op)) != 1 or
533 len(self.graph.get_upstream_operations(bn_op)) != 1):
534 ppq_warning(f'PPQ can not merge operation {computing_op.name} and {bn_op.name}, '
535 'this is not suppose to happen with your network, '
536 'network with batchnorm inside might not be able to quantize and deploy.')
537 continue
538
539 assert len(bn_op.parameters) == 4, 'BatchNorm should have 4 parameters, namely alpha, beta, mean, var'
540 alpha = bn_op.parameters[0].value
541 beta = bn_op.parameters[1].value
542 mean = bn_op.parameters[2].value
543 var = bn_op.parameters[3].value
544 epsilon = bn_op.attributes.get('epsilon', 1e-5)
545
546 if computing_op.num_of_parameter == 1:
547 w = computing_op.parameters[0].value # no bias.
548 assert isinstance(w, torch.Tensor), 'values of parameters are assumed as torch Tensor'
549 if computing_op.type == 'ConvTranspose':
550 b = torch.zeros(w.shape[1] * computing_op.attributes.get('group', 1))
551 elif computing_op.type == 'Gemm' and computing_op.attributes.get('transB', 0) == 0:
552 b = torch.zeros(w.shape[1])
553 else:
554 b = torch.zeros(w.shape[0])
555 else:
556 w, b = [var.value for var in computing_op.parameters[: 2]] # has bias.
557
558 if computing_op.type == 'Conv':
559
560 # calculate new weight and bias
561 scale = alpha / torch.sqrt(var + epsilon)
562 w = w * scale.reshape([-1] + [1] * (w.ndim - 1))
563 b = alpha * (b - mean) / torch.sqrt(var + epsilon) + beta
564
565 elif computing_op.type == 'Gemm':
566
567 # calculate new weight and bias
568 scale = alpha / torch.sqrt(var + epsilon)
569 if computing_op.attributes.get('transB', 0):
570 w = w * scale.reshape([-1, 1])
571 else:
572 w = w * scale.reshape([1, -1])
573 b = alpha * (b - mean) / torch.sqrt(var + epsilon) + beta
574

Callers 1

process (Method) · 0.95

Calls 15

path_matching (Method) · 0.95
SearchableGraph (Class) · 0.90
ppq_warning (Function) · 0.90
Operation (Class) · 0.85
Variable (Class) · 0.85
tolist (Method) · 0.80
remove_operation (Method) · 0.80
append_operation (Method) · 0.80
append_variable (Method) · 0.80
copy (Method) · 0.45

Tested by

no test coverage detected