(
self,
in_features,
out_features,
bias=True,
dtype=None,
tp_group=None,
tp_size=1,
tp_rank=0,
gather_output=True,
quant_mode=QuantMode.use_weight_only(),
transa=False,
transb=False,
is_qkv=False,
prefer_managed_weight=True,
)
| 702 | class WeightOnlyQuantLinear(Linear): |
| 703 | |
def __init__(
    self,
    in_features,
    out_features,
    bias=True,
    dtype=None,
    tp_group=None,
    tp_size=1,
    tp_rank=0,
    gather_output=True,
    quant_mode=QuantMode.use_weight_only(),
    transa=False,
    transb=False,
    is_qkv=False,
    prefer_managed_weight=True,
):
    """Build a linear layer whose weight is stored in weight-only quantized form.

    Both feature dimensions are rounded up to a multiple of ``64 * tp_size``
    before being handed to ``Linear.__init__``; ``self.is_padded`` records
    whether either dimension had to grow. The weight created by the parent
    is then replaced by an int8 ``Parameter`` sized for the packed quantized
    values, and a per-output-channel scale ``Parameter`` is added.

    Args:
        in_features: Input feature count (padded up if needed).
        out_features: Output feature count (padded up if needed).
        bias: Forwarded to ``Linear.__init__``.
        dtype: Forwarded to ``Linear.__init__``; also the dtype of
            ``per_channel_scale``.
        tp_group: Forwarded to ``Linear.__init__``.
        tp_size: Forwarded to ``Linear.__init__``; also scales the padding
            multiple.
        tp_rank: Stored on the instance as ``self.tp_rank``.
        gather_output: Forwarded to ``Linear.__init__``.
        quant_mode: Must report either int8 or int4 weight-only
            quantization; stored as ``self.quant_mode``.
        transa: Stored on the instance as ``self.transa``.
        transb: Stored on the instance as ``self.transb``.
        is_qkv: Forwarded to ``Linear.__init__``.
        prefer_managed_weight: Forwarded to ``Linear.__init__``.

    Raises:
        ValueError: If ``quant_mode`` is neither int8 nor int4 weight-only.
    """
    # Resolve the quantization flavour before any layer state is built.
    # Previously an unsupported quant_mode fell through both branches and
    # surfaced later as an UnboundLocalError on quant_type_size_in_bits.
    if quant_mode.is_int8_weight_only():
        weight_only_quant_mode = 1
        quant_type_size_in_bits = 8
    elif quant_mode.is_int4_weight_only():
        weight_only_quant_mode = 2
        quant_type_size_in_bits = 4
    else:
        raise ValueError(
            "WeightOnlyQuantLinear requires an int8 or int4 weight-only "
            f"quant_mode, got {quant_mode!r}")

    # Round each dimension up to the next multiple of 64 * tp_size.
    multiple = 64 * tp_size
    self.is_padded = False
    if in_features % multiple > 0:
        in_features = math.ceil(in_features / multiple) * multiple
        self.is_padded = True
    if out_features % multiple > 0:
        out_features = math.ceil(out_features / multiple) * multiple
        self.is_padded = True

    super().__init__(in_features,
                     out_features,
                     bias=bias,
                     dtype=dtype,
                     tp_group=tp_group,
                     tp_size=tp_size,
                     gather_output=gather_output,
                     is_qkv=is_qkv,
                     prefer_managed_weight=prefer_managed_weight)
    self.weight_only_quant_mode = weight_only_quant_mode

    # The quantized weight is declared as a stand-in int8 tensor: the last
    # dimension is scaled by bits / 8, so int4 values take half the columns.
    # NOTE(review): self.in_features, self.out_features and
    # self.prefer_managed_weight are presumably set by Linear.__init__ -
    # confirm against the parent class.
    self.weight = Parameter(shape=(self.in_features,
                                   int(self.out_features *
                                       quant_type_size_in_bits / 8)),
                            dtype="int8",
                            prefer_managed=self.prefer_managed_weight)

    # One scale value per output channel.
    scale_shape = (self.out_features, )
    self.per_channel_scale = Parameter(shape=scale_shape, dtype=dtype)

    self.transa = transa
    self.transb = transb
    self.tp_rank = tp_rank
    if self.is_padded:
        # NOTE(review): tp_dim is overridden to -1 only for padded layers;
        # the consumer of tp_dim is not visible here - confirm intent.
        self.tp_dim = -1
    self.quant_mode = quant_mode
| 760 | |
| 761 | def forward(self, x, lora_runtime_params=None): |
nothing calls this directly
no test coverage detected