Replace BatchNormalization with an equivalent 1x1 depthwise Convolution.
(self, dimension: int = 2)
| 81 | ] |
| 82 | |
def replace_batchnorm_to_conv(self, dimension: int = 2):
    """Replace every BatchNormalization op with an equivalent 1x1 depthwise Conv.

    The replacement is a ``dimension``-D convolution (not necessarily 1D) whose
    kernel size is 1 along every spatial axis and whose ``group`` equals the
    channel count, so each channel is scaled and shifted independently. The BN
    parameters are folded into the conv weight and bias:

        w = alpha / sqrt(var + epsilon)
        b = beta - alpha * mean / sqrt(var + epsilon)

    Each matching op is rewritten in place: its type and attributes are
    replaced, its mean/var input variables are removed from the graph, and its
    scale/bias input variables are overwritten with the folded weight and bias.

    Args:
        dimension: number of spatial axes of the BN input (length of the
            generated ``kernel_shape``/``strides``/``dilations``). Must be >= 1.

    Raises:
        ValueError: if ``dimension`` is less than 1.
        AssertionError: if a BatchNormalization op does not have exactly 4
            parameters.
    """
    if dimension < 1:
        # A 0-D (or negative) kernel_shape would produce an invalid Conv.
        raise ValueError(f'dimension must be a positive integer, got {dimension}.')

    for op in self.graph.operations.values():
        if op.type != 'BatchNormalization':
            continue

        ppq_warning(f'Isolated BatchNormalization({op.name}) was detected, '
                    f'PPQ will replace it to 1*1 Convolution({dimension}D).')

        assert len(op.parameters) == 4, "BatchNorm should have 4 parameters, namely alpha, beta, mean, var"
        # Parameter order (per the assert message): scale, bias, mean, var.
        alpha, beta, mean, var = (param.value for param in op.parameters)
        epsilon = op.attributes.get("epsilon", 1e-5)  # 1e-5 is the ONNX default

        with torch.no_grad():
            std = torch.sqrt(var + epsilon)
            # Depthwise weight layout: [C, 1, 1, ..., 1] -> one 1x1 kernel per channel.
            w = (alpha / std).reshape([-1, 1] + [1] * dimension)
            b = alpha * (-mean) / std + beta

        op.type = 'Conv'
        op.attributes.clear()
        op.attributes['kernel_shape'] = [1] * dimension
        op.attributes['strides'] = [1] * dimension
        op.attributes['dilations'] = [1] * dimension
        # ONNX pads lists a begin and an end value for every spatial axis.
        op.attributes['pads'] = [0, 0] * dimension
        # One group per channel makes this a depthwise convolution.
        op.attributes['group'] = w.numel()

        # Drop the trailing var and mean inputs so the Conv has exactly 3 inputs:
        # X, W (former scale) and B (former bias). Calling this twice relies on
        # remove_variable also detaching the variable from op.inputs.
        self.graph.remove_variable(op.inputs[-1])
        self.graph.remove_variable(op.inputs[-1])

        # NOTE(review): ONNX BN may carry optional running_mean/running_var
        # outputs (training mode); they are not touched here -- confirm such
        # ops never reach this pass.
        with torch.no_grad():
            op.inputs[1].value = w
            op.inputs[2].value = b
| 118 | def replace_batchnorm_to_scale(self, dimension: int = 4): |
| 119 | """ Replace Batchnorm to Mul + Add. |
no test coverage detected