MCPcopy
hub / github.com/OpenPPL/ppq / replace_batchnorm_to_conv

Method replace_batchnorm_to_conv

ppq/IR/morph.py:83–116  ·  view source on GitHub ↗

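Replace each BatchNormalization op with an equivalent 1*1 grouped convolution whose number of spatial axes is given by `dimension` (2 by default).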

(self, dimension: int = 2)

Source from the content-addressed store, hash-verified

81 ]
82
def replace_batchnorm_to_conv(self, dimension: int = 2):
    """Rewrite every BatchNormalization op as an equivalent 1*1 convolution.

    For each op the four parameters (alpha, beta, mean, var) are folded
    into a convolution weight and bias:

        weight = alpha / sqrt(var + epsilon)
        bias   = alpha * (-mean) / sqrt(var + epsilon) + beta

    The op is converted in place: its type becomes 'Conv', its attributes
    are replaced with those of a kernel-size-1, stride-1, unpadded
    convolution with one group per weight element, and its last two
    inputs are removed from the graph so that the remaining parameter
    inputs carry the new weight and bias.

    Args:
        dimension: Number of spatial axes of the generated convolution
            (2 produces a weight reshaped to [-1, 1, 1, 1]).

    Raises:
        AssertionError: If a BatchNormalization op does not have exactly
            four parameters.
    """
    for op in self.graph.operations.values():
        if op.type != 'BatchNormalization':
            continue

        ppq_warning(f'Isolated BatchNormalization({op.name}) was detected, '
                    f'PPQ will replace it to 1*1 Convolution({dimension}D).')

        assert len(op.parameters) == 4, "BatchNorm should have 4 parameters, namely alpha, beta, mean, var"
        alpha, beta, mean, var = (param.value for param in op.parameters)
        # Read epsilon before the attributes are cleared below.
        epsilon = op.attributes.get("epsilon", 1e-5)

        with torch.no_grad():
            # Shared denominator of both the folded weight and the folded bias.
            std = torch.sqrt(var + epsilon)
            w = (alpha / std).reshape([-1, 1] + [1] * dimension)
            b = alpha * (-mean) / std + beta

        op.type = 'Conv'
        op.attributes.clear()
        # Each attribute gets its own list object so later edits to one
        # attribute cannot leak into another.
        op.attributes['kernel_shape'] = [1] * dimension
        op.attributes['strides'] = [1] * dimension
        op.attributes['dilations'] = [1] * dimension
        op.attributes['pads'] = [0, 0] * dimension
        # One group per weight element: every channel is scaled independently.
        op.attributes['group'] = w.numel()

        # Remove the last two inputs so the conv has exactly three inputs.
        # NOTE(review): relies on remove_variable detaching the variable
        # from op.inputs, so the second call sees a new last input — confirm.
        self.graph.remove_variable(op.inputs[-1])
        self.graph.remove_variable(op.inputs[-1])

        # The former alpha and beta inputs now hold the folded weight and bias.
        with torch.no_grad():
            op.inputs[1].value = w
            op.inputs[2].value = b
117
118 def replace_batchnorm_to_scale(self, dimension: int = 4):
119 """ Replace Batchnorm to Mul + Add.

Callers 2

processMethod · 0.95
testBnToConv.pyFile · 0.80

Calls 3

ppq_warningFunction · 0.90
remove_variableMethod · 0.80
clearMethod · 0.45

Tested by

no test coverage detected