hub / github.com/OpenPPL/ppq / replace_batchnorm_to_scale

Method replace_batchnorm_to_scale

ppq/IR/morph.py:118–152

Replaces an isolated BatchNormalization operation with Mul + Add. By default this function creates 4D Mul and Add parameters, corresponding to the NCHW layout.

(self, dimension: int = 4)
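The replacement relies on the fact that inference-mode batch normalization is an affine map per channel: alpha * (x - mean) / sqrt(var + epsilon) + beta equals x * multiplier + bias, with multiplier = alpha / sqrt(var + epsilon) and bias = -mean * multiplier + beta. The sketch below checks that identity numerically. It is an illustration only: NumPy stands in for the torch tensors used in the source, and `fold_batchnorm` and `reference_batchnorm` are hypothetical helper names, not part of PPQ.

```python
import numpy as np


def fold_batchnorm(alpha, beta, mean, var, epsilon=1e-5):
    """Return (multiplier, bias) so that BN(x) == x * multiplier + bias."""
    multiplier = alpha / np.sqrt(var + epsilon)
    bias = -mean * multiplier + beta
    return multiplier, bias


def reference_batchnorm(x, alpha, beta, mean, var, epsilon=1e-5):
    """Inference-mode batch norm over the channel axis of an NCHW array."""
    shape = (1, -1, 1, 1)  # broadcast per-channel statistics over N, H, W
    return (alpha.reshape(shape) * (x - mean.reshape(shape))
            / np.sqrt(var.reshape(shape) + epsilon) + beta.reshape(shape))


rng = np.random.default_rng(0)
channels = 3
x = rng.standard_normal((2, channels, 4, 4))
alpha = rng.standard_normal(channels)
beta = rng.standard_normal(channels)
mean = rng.standard_normal(channels)
var = rng.random(channels) + 0.5  # variances must be positive

multiplier, bias = fold_batchnorm(alpha, beta, mean, var)

# Mul followed by Add, with parameters reshaped for a 4D (NCHW) input.
folded = x * multiplier.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)
expected = reference_batchnorm(x, alpha, beta, mean, var)
max_error = float(np.abs(folded - expected).max())
```

With float64 inputs the two results should agree up to rounding error, which is why the rewrite can drop the mean and variance parameters entirely.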

Source from the content-addressed store, hash-verified

            op.inputs[2].value = b

    def replace_batchnorm_to_scale(self, dimension: int = 4):
        """ Replace Batchnorm to Mul + Add.

        By default this function created a 4d mul + add corresponding to NCHW layout.
        """
        graph = self.graph
        for op in [_ for _ in self.graph.operations.values()]:

            if op.type == 'BatchNormalization':
                ppq_warning(f'Isolated BatchNormalization({op.name}) was detected, '
                            f'PPQ will replace it to Mul + Add({dimension}D).')

                assert len(op.parameters) == 4, "BatchNorm should have 4 parameters, namely alpha, beta, mean, var"
                alpha = op.parameters[0].value
                beta  = op.parameters[1].value
                mean  = op.parameters[2].value
                var   = op.parameters[3].value
                epsilon = op.attributes.get("epsilon", 1e-5)

                with torch.no_grad():
                    multiplier = alpha / torch.sqrt(var + epsilon)
                    bias = (-mean) * multiplier + beta

                for var in [_ for _ in op.parameters]: graph.remove_variable(var)
                graph.create_variable(value=multiplier, is_parameter=True, dest_ops=[op])
                op.type = 'Mul'
                op.attributes.clear()

                add = graph.create_operation(op_type='Add')
                graph.insert_op_after(A=add, B=op)
                graph.create_variable(value=bias, is_parameter=True, dest_ops=[add])

                if dimension > 1:
                    op.parameters[0].value = op.parameters[0].value.reshape([1, -1] + [1] * (dimension - 2))
                    add.parameters[0].value = add.parameters[0].value.reshape([1, -1] + [1] * (dimension - 2))


class GraphFormatter(GraphCommandProcessor):
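The `dimension` argument only controls how the per-channel multiplier and bias are reshaped so that they broadcast against the activation. The following sketch reproduces the shape expression from the listing in plain Python; `broadcast_shape` is a hypothetical helper introduced here for illustration, not a PPQ function.

```python
def broadcast_shape(dimension: int):
    """Shape passed to reshape() for the Mul/Add parameters.

    Returns None when dimension <= 1, mirroring the `if dimension > 1`
    guard in the listing, in which case the parameters stay 1D.
    """
    if dimension > 1:
        # Channel axis is axis 1; every other axis gets size 1.
        return [1, -1] + [1] * (dimension - 2)
    return None


shapes = {d: broadcast_shape(d) for d in (1, 2, 3, 4, 5)}
for d, shape in shapes.items():
    print(d, shape)
```

The default `dimension=4` gives `[1, -1, 1, 1]`, which matches NCHW activations. A BatchNormalization that follows a 2D (N, C) or 3D (N, C, L) tensor would presumably need `dimension=2` or `dimension=3` instead.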

Callers 2

process · Method · 0.95
testBnToConv.py · File · 0.80

Calls 6

ppq_warning · Function · 0.90
remove_variable · Method · 0.80
create_variable · Method · 0.80
create_operation · Method · 0.80
insert_op_after · Method · 0.80
clear · Method · 0.45

Tested by

no test coverage detected