(
self,
normalized_shape,
eps=1e-06,
elementwise_affine=True,
dtype=None,
quant_mode=QuantMode(0),
bias=False,
clamp_val=None,
)
| 298 | class SmoothQuantRmsNorm(Module): |
| 299 | |
def __init__(
    self,
    normalized_shape,
    eps=1e-06,
    elementwise_affine=True,
    dtype=None,
    quant_mode=QuantMode(0),
    bias=False,
    clamp_val=None,
):
    """Create the parameters of the SmoothQuant RMS-norm layer.

    Args:
        normalized_shape: An int or an iterable of ints. An int is wrapped
            into a one-element tuple; any other value is converted with
            ``tuple()``. The result is the shape of ``weight`` and ``bias``.
        eps: Stored unchanged on ``self.eps``.
        elementwise_affine: When true, a ``weight`` parameter of shape
            ``normalized_shape`` is created; otherwise ``weight`` is
            registered as ``None``.
        dtype: Passed to the ``weight``, ``bias`` and ``scale_to_int``
            parameters and stored on ``self.dtype``.
        quant_mode: Quantization mode; ``has_act_and_weight_quant()`` must
            return a true value.
            NOTE(review): the default ``QuantMode(0)`` presumably has no
            quantization enabled, in which case constructing the layer
            without an explicit ``quant_mode`` raises -- confirm.
        bias: When true, a ``bias`` parameter of shape ``normalized_shape``
            is created; otherwise ``bias`` is registered as ``None``.
        clamp_val: Optional two-element list or tuple. When truthy it is
            stored as a float32 buffer parameter; any falsy value
            (including ``None``) leaves ``clamp_val`` registered as
            ``None``.

    Raises:
        ValueError: If ``quant_mode.has_act_and_weight_quant()`` is false,
            or if ``clamp_val`` is truthy but is not a list/tuple of
            length 2.
    """
    super().__init__()
    if isinstance(normalized_shape, int):
        normalized_shape = (normalized_shape, )
    if not quant_mode.has_act_and_weight_quant():
        raise ValueError(
            "SmoothQuant Rms norm has to have some quantization mode set")
    self.normalized_shape = tuple(normalized_shape)
    self.elementwise_affine = elementwise_affine
    if self.elementwise_affine:
        self.weight = Parameter(shape=self.normalized_shape, dtype=dtype)
    else:
        self.register_parameter('weight', None)

    if bias:
        self.bias = Parameter(shape=self.normalized_shape, dtype=dtype)
    else:
        self.register_parameter('bias', None)

    if clamp_val:
        # A tuple is accepted alongside a list: the value is only handed to
        # np.array, which treats both the same way.
        if not (isinstance(clamp_val, (list, tuple))
                and len(clamp_val) == 2):
            raise ValueError(f'unsupported clamp_val {clamp_val}')
        self.clamp_val = Parameter(np.array(clamp_val, dtype=np.float32),
                                   dtype='float32',
                                   is_buffer=True)
    else:
        self.register_parameter('clamp_val', None)

    self.eps = eps
    self.dtype = dtype
    self.quant_mode = quant_mode

    # The guard near the top already rejected every quant_mode without
    # activation-and-weight quantization, so this parameter is always
    # created; no "registered as None" fallback is needed.
    self.scale_to_int = Parameter(shape=(1, ), dtype=dtype)
| 344 | |
| 345 | def forward(self, x): |
| 346 | weight = None if self.weight is None else self.weight.value |
nothing calls this directly
no test coverage detected