(
self,
normalized_shape,
eps=1e-06,
elementwise_affine=True,
dtype=None,
quant_mode=QuantMode(0),
bias=False,
clamp_val=None,
)
| 512 | class Fp8RowwiseRmsNorm(Module): |
| 513 | |
def __init__(
    self,
    normalized_shape,
    eps=1e-06,
    elementwise_affine=True,
    dtype=None,
    quant_mode=QuantMode(0),
    bias=False,
    clamp_val=None,
):
    """Create the parameters for an FP8 row-wise RMS norm layer.

    Args:
        normalized_shape: An int or an iterable of ints. A bare int is
            wrapped into a one-element tuple; the result is stored as a
            tuple in ``self.normalized_shape``.
        eps: Stored unchanged in ``self.eps`` (presumably the epsilon
            consumed by the normalization in ``forward`` -- TODO confirm).
        elementwise_affine: When true, a ``weight`` Parameter of shape
            ``normalized_shape`` is created; otherwise ``weight`` is
            registered as ``None``.
        dtype: Passed to the ``weight`` / ``bias`` Parameters and stored
            in ``self.dtype``.
        quant_mode: Must report ``has_fp8_rowwise()`` as true.
            NOTE(review): the default ``QuantMode(0)`` presumably fails
            this check, making the argument effectively required --
            confirm against ``QuantMode``.
        bias: When true, a ``bias`` Parameter of shape
            ``normalized_shape`` is created; otherwise ``bias`` is
            registered as ``None``.
        clamp_val: Optional two-element list or tuple, stored as a
            float32 buffer Parameter (presumably the lower/upper clamp
            bounds -- TODO confirm). Any falsy value (``None``, ``[]``,
            ``()``) means "no clamp" and registers ``clamp_val`` as
            ``None``.

    Raises:
        ValueError: If ``quant_mode`` does not have FP8 row-wise enabled,
            or if ``clamp_val`` is truthy but is not a two-element list
            or tuple.
    """
    super().__init__()
    if isinstance(normalized_shape, int):
        normalized_shape = (normalized_shape, )
    if not quant_mode.has_fp8_rowwise():
        raise ValueError(
            "Fp8 Rowwise Rms norm has to have some quantization mode set")
    self.normalized_shape = tuple(normalized_shape)
    self.elementwise_affine = elementwise_affine
    if self.elementwise_affine:
        self.weight = Parameter(shape=self.normalized_shape, dtype=dtype)
    else:
        # Register the name so later `self.weight is None` checks work.
        self.register_parameter('weight', None)

    if bias:
        self.bias = Parameter(shape=self.normalized_shape, dtype=dtype)
    else:
        self.register_parameter('bias', None)

    if clamp_val:
        # Accept a tuple as well as a list: np.array() treats both the
        # same, so rejecting an equivalent pair only because of its
        # container type was needlessly strict.
        is_pair = (isinstance(clamp_val, (list, tuple))
                   and len(clamp_val) == 2)
        if not is_pair:
            raise ValueError(f'unsupported clamp_val {clamp_val}')
        self.clamp_val = Parameter(np.array(clamp_val, dtype=np.float32),
                                   dtype='float32',
                                   is_buffer=True)
    else:
        self.register_parameter('clamp_val', None)

    self.eps = eps
    self.dtype = dtype
    self.quant_mode = quant_mode
| 554 | |
| 555 | def forward(self, x): |
| 556 | weight = None if self.weight is None else self.weight.value |
nothing calls this directly
no test coverage detected