(self, tensor: torch.Tensor, config: TensorQuantizationConfig)
| 120 | self.var.value.copy_(self.param_backup) |
| 121 | |
def __call__(self, tensor: torch.Tensor, config: TensorQuantizationConfig) -> torch.Tensor:
    """Fake-quantize ``tensor`` with a learned (rather than nearest) rounding.

    The input is divided by the scale and floored onto the integer grid.
    The term ``self.reg.rectified_sigmoid(self.rounding)`` is then added.
    This is presumably an AdaRound-style soft rounding offset; confirm
    against the regularizer class. The result is shifted by the offset,
    clamped to ``[quant_min, quant_max]``, and mapped back to the
    floating-point domain.

    Args:
        tensor: floating-point tensor to quantize.
        config: provides ``scale``, ``offset``, ``quant_min``, ``quant_max``
            and ``policy``; for per-channel policies also ``channel_axis``.

    Returns:
        The dequantized tensor.
    """
    step_size, zero_point = config.scale, config.offset
    if config.policy.has_property(QuantizationProperty.PER_CHANNEL):
        # Give the per-channel parameters size -1 on channel_axis and 1
        # everywhere else so they broadcast against ``tensor``.
        broadcast_shape = [
            -1 if dim == config.channel_axis else 1 for dim in range(tensor.ndim)
        ]
        step_size = step_size.view(broadcast_shape)
        zero_point = zero_point.view(broadcast_shape)
    floored = (tensor / step_size).floor()
    grid = floored + self.reg.rectified_sigmoid(self.rounding)
    clipped = torch.clamp(grid + zero_point, config.quant_min, config.quant_max)
    return (clipped - zero_point) * step_size
| 133 | |
def regularization_loss(self, step: int) -> torch.Tensor:
    """Return the rounding regularization term at training iteration ``step``.

    Delegates to ``self.reg.forward``, passing the learnable rounding
    parameter ``self.rounding`` as ``r`` and ``step`` as ``iter``.
    """
    regularizer = self.reg
    return regularizer.forward(r=self.rounding, iter=step)
nothing calls this directly
no test coverage detected