Replace BatchNorm with Mul + Add. By default, this function creates a 4D Mul + Add corresponding to the NCHW layout.
(self, dimension: int = 4)
| 116 | op.inputs[2].value = b |
| 117 | |
def replace_batchnorm_to_scale(self, dimension: int = 4):
    """Replace every BatchNormalization op in the graph with Mul + Add.

    Each ``BatchNormalization`` operation is rewritten in place:

    * the op itself becomes a ``Mul`` whose parameter is
      ``alpha / sqrt(var + epsilon)``;
    * a new ``Add`` op is inserted after it (via ``graph.insert_op_after``)
      whose parameter is ``beta - mean * alpha / sqrt(var + epsilon)``.

    By default the folded parameters are reshaped to ``[1, C, 1, 1]`` (4D),
    which puts the channel dimension on axis 1 as in the NCHW layout.

    Args:
        dimension (int, optional): rank of the Mul/Add parameters. They are
            reshaped to ``[1, -1] + [1] * (dimension - 2)``. When
            ``dimension <= 1`` they keep the shape produced by the
            arithmetic. Defaults to 4.

    Raises:
        AssertionError: if a BatchNormalization op does not carry exactly
            4 parameters (alpha, beta, mean, var).
    """
    graph = self.graph
    # Snapshot the operations: the loop body creates new ops in the graph.
    for op in list(graph.operations.values()):
        if op.type != 'BatchNormalization':
            continue

        # Validate before announcing the rewrite, so a malformed op never
        # produces a misleading "will replace" warning.
        assert len(op.parameters) == 4, "BatchNorm should have 4 parameters, namely alpha, beta, mean, var"
        ppq_warning(f'Isolated BatchNormalization({op.name}) was detected, '
                    f'PPQ will replace it to Mul + Add({dimension}D).')

        alpha, beta, mean, variance = (param.value for param in op.parameters)
        epsilon = op.attributes.get("epsilon", 1e-5)

        with torch.no_grad():
            multiplier = alpha / torch.sqrt(variance + epsilon)
            bias = (-mean) * multiplier + beta

        if dimension > 1:
            # Broadcastable per-channel shape: [1, C, 1, ..., 1].
            shape = [1, -1] + [1] * (dimension - 2)
            multiplier = multiplier.reshape(shape)
            bias = bias.reshape(shape)

        # Detach the four BN parameter inputs (distinct loop name so the
        # variance tensor above is not shadowed), then attach the scale.
        for param in list(op.parameters):
            graph.remove_variable(param)
        graph.create_variable(value=multiplier, is_parameter=True, dest_ops=[op])
        op.type = 'Mul'
        op.attributes.clear()  # BN attributes (e.g. epsilon) are meaningless on Mul

        add = graph.create_operation(op_type='Add')
        graph.insert_op_after(A=add, B=op)
        graph.create_variable(value=bias, is_parameter=True, dest_ops=[add])
| 153 | |
| 154 | |
| 155 | class GraphFormatter(GraphCommandProcessor): |
no test coverage detected