(self)
| 105 | return tensors |
| 106 | |
def finalize(self) -> None:
    """Bake the learned rounding decision into ``self.var.value``.

    Each weight element is rounded down (``floor(weight / scale)``) and then
    incremented by one wherever ``self.rounding >= 0``. The integer grid value is
    shifted by ``offset`` and clamped to ``[quant_min, quant_max]``. It is then
    dequantized with ``(q - offset) * scale``, and the resulting fake-quantized
    tensor replaces ``self.var.value``.

    For per-channel policies, ``scale`` and ``offset`` are reshaped so they
    broadcast along ``config.channel_axis`` of the weight.

    NOTE(review): ``self.rounding`` looks like an AdaRound-style per-element
    rounding variable broadcastable to the weight's shape — confirm.
    """
    # Computed without autograd, like ``withdraw``: the written-back value is a
    # final, fixed weight and should not carry a graph back to scale/offset.
    with torch.no_grad():
        weight = self.var.value
        scale, offset = self.config.scale, self.config.offset
        if self.config.policy.has_property(QuantizationProperty.PER_CHANNEL):
            # Normalize a negative channel axis (e.g. -1). Otherwise no axis
            # matches and view() fails for a multi-channel scale.
            channel_axis = self.config.channel_axis
            if channel_axis < 0:
                channel_axis += weight.ndim
            shape = [-1 if axis == channel_axis else 1 for axis in range(weight.ndim)]
            scale = scale.view(shape)
            offset = offset.view(shape)
        # 1.0 where the element should be rounded up, 0.0 where it stays floored.
        round_up = (self.rounding >= 0).float()
        quantized = torch.clamp(
            (weight / scale).floor() + round_up + offset,
            self.config.quant_min,
            self.config.quant_max,
        )
        self.var.value = (quantized - offset) * scale
| 117 | |
| 118 | def withdraw(self) -> None: |
| 119 | with torch.no_grad(): |
no test coverage detected