(
self,
in_features,
out_features,
bias=True,
dtype=None,
tp_group=None,
tp_size=1,
quant_mode=QuantMode(0),
)
| 434 | class QServeW4A8RowLinear(RowLinear): |
| 435 | |
def __init__(
    self,
    in_features,
    out_features,
    bias=True,
    dtype=None,
    tp_group=None,
    tp_size=1,
    quant_mode=QuantMode(0),
):
    """Set up the quantized weight and scale parameters of the layer.

    Args:
        in_features: Forwarded positionally to ``RowLinear.__init__``.
        out_features: Forwarded positionally to ``RowLinear.__init__``.
        bias: Forwarded to ``RowLinear.__init__``.
        dtype: Must be ``"float16"``; also forwarded to
            ``RowLinear.__init__``.
        tp_group: Forwarded to ``RowLinear.__init__``.
        tp_size: Forwarded to ``RowLinear.__init__``.
        quant_mode: Quantization mode; ``is_qserve_w4a8()`` must hold.

    Raises:
        AssertionError: If ``dtype`` is not ``"float16"`` or ``quant_mode``
            is not a QServe W4A8 mode. Note that the defaults
            (``dtype=None``) do not pass the dtype check.
    """
    # Currently the kernel only supports float16 output
    assert dtype == "float16"

    super().__init__(
        in_features,
        out_features,
        bias=bias,
        dtype=dtype,
        tp_group=tp_group,
        tp_size=tp_size,
    )

    self.quant_mode = quant_mode
    assert self.quant_mode.is_qserve_w4a8()

    # Only supports 128g now. -1 marks the non-per-group configuration.
    self.group_size = 128 if self.quant_mode.has_per_group_scaling() else -1

    # Shapes are derived from self.in_features / self.out_features (as left
    # by the base-class initializer), not from the raw constructor arguments.
    # NOTE(review): in_features // 2 presumably reflects two 4-bit weights
    # packed per int8 element -- confirm against the kernel's layout.
    packed_weight_shape = (self.out_features, self.in_features // 2)
    self.weight = Parameter(shape=packed_weight_shape, dtype="int8")

    per_channel_shape = (self.out_features, )
    self.s1_scales = Parameter(shape=per_channel_shape, dtype="float16")

    if self.group_size == -1:
        # Without per-group scaling only the float16 s1_szeros is needed.
        self.s1_szeros = Parameter(shape=per_channel_shape, dtype="float16")
        return

    # Per-group scaling: int8 scales and zeros, one row per input group.
    per_group_shape = (self.in_features // self.group_size,
                       self.out_features)
    self.s2_scales = Parameter(shape=per_group_shape, dtype="int8")
    self.s2_szeros = Parameter(shape=per_group_shape, dtype="int8")
| 479 | |
| 480 | def forward(self, x, all_reduce_params=None): |
| 481 | if self.group_size == -1: |
nothing calls this directly
no test coverage detected