(ctx, tensor: torch.Tensor, config: TensorQuantizationConfig)
| 142 | """ |
@staticmethod
def forward(ctx, tensor: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor:
    """Fake-quantize ``tensor`` per channel: quantize, clamp, then dequantize.

    For every slice along ``config.channel_axis``, the slice's min and max are
    passed to ``minmax_to_scale_offset(min, max, config)`` to obtain a
    (scale, offset) pair. The tensor is divided by the per-channel scale,
    rounded with ``ppq_tensor_round(..., config.rounding)``, shifted by the
    offset, clamped to ``[config.quant_min, config.quant_max]`` and finally
    mapped back with ``(q - offset) * scale``.

    Args:
        ctx: autograd context; not used by this forward pass.
        tensor: the tensor to fake-quantize.
        config: quantization config; this method reads ``channel_axis``,
            ``rounding``, ``quant_min`` and ``quant_max``. A negative
            ``channel_axis`` counts from the last dimension.

    Returns:
        A tensor with the same shape as ``tensor`` holding the dequantized
        values.
    """
    from ppq.quantization.observer.range import minmax_to_scale_offset

    # Normalize a negative axis. transpose() accepts negative dims, but the
    # broadcast-shape comparison below only matches non-negative indices; a
    # raw negative axis would yield an all-ones shape and make view() fail.
    channel_axis = config.channel_axis
    if channel_axis < 0:
        channel_axis += tensor.ndim

    # Bring the channel axis to the front and collapse the rest into one
    # dimension -> (num_channels, elements_per_channel). unsqueeze(-1) keeps
    # flatten(start_dim=1) valid when the input is 1-D.
    channelwise_view = tensor.transpose(dim0=0, dim1=channel_axis).unsqueeze(-1)
    channelwise_view = torch.flatten(channelwise_view, start_dim=1)

    channel_mins = channelwise_view.min(dim=1).values.tolist()
    channel_maxs = channelwise_view.max(dim=1).values.tolist()

    scales, offsets = [], []
    for _min, _max in zip(channel_mins, channel_maxs):
        s, o = minmax_to_scale_offset(_min, _max, config)
        scales.append(s)
        offsets.append(o)

    scales = torch.tensor(scales, dtype=torch.float32, device=tensor.device)
    offsets = torch.tensor(offsets, dtype=torch.float32, device=tensor.device)

    # Broadcast shape such as [1, -1, 1, 1]: the only -1 sits at the channel axis.
    shape = [-1 if axis == channel_axis else 1 for axis in range(tensor.ndim)]
    scales, offsets = scales.view(shape), offsets.view(shape)

    tensor = ppq_tensor_round(tensor / scales, config.rounding) + offsets
    tensor = torch.clamp(tensor, config.quant_min, config.quant_max)
    tensor = (tensor - offsets) * scales
    return tensor
| 170 | |
| 171 | @ staticmethod |
| 172 | def backward(ctx, dy: torch.Tensor): |
nothing calls this directly
no test coverage detected