(
self,
output_dtype: str = 'int8',
scaling_factor_dtype: str = 'float32',
in_channels: int = -1,
axis=-1,
)
| 56 | """ |
| 57 | |
def __init__(
    self,
    output_dtype: str = 'int8',
    scaling_factor_dtype: str = 'float32',
    in_channels: int = -1,
    axis=-1,
) -> None:
    """Set up the layer and register its ``scaling_factor`` parameter.

    Args:
        output_dtype: Dtype name stored on the layer; ``forward`` passes it
            to ``quantize`` along with the input and the scaling factor.
        scaling_factor_dtype: Dtype name given to the ``scaling_factor``
            ``Parameter``.
        in_channels: Length of the scaling-factor vector. Only used when
            ``axis`` is not -1.
        axis: When -1 (the default), the scaling factor has shape ``()``.
            Any other value gives it shape ``(in_channels,)``. The value is
            also stored on the layer.
    """
    super().__init__()
    # Pick the parameter shape: a scalar when axis == -1, otherwise a 1-D
    # vector of length in_channels (presumably one scale per channel —
    # confirm against quantize()).
    # NOTE(review): in_channels is not validated here; with the default -1
    # and axis != -1 the requested shape is (-1,).
    if axis == -1:
        scale_shape = ()
    else:
        scale_shape = (in_channels, )
    self.scaling_factor = Parameter(shape=scale_shape,
                                    dtype=scaling_factor_dtype)
    self.output_dtype = output_dtype
    self.axis = axis
| 71 | |
| 72 | def forward(self, x): |
| 73 | return quantize(x, self.scaling_factor.value, self.output_dtype, |
no test coverage detected