MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / __init__

Method __init__

tensorrt_llm/quantization/layers.py:704–759  ·  view source on GitHub ↗
(
        self,
        in_features,
        out_features,
        bias=True,
        dtype=None,
        tp_group=None,
        tp_size=1,
        tp_rank=0,
        gather_output=True,
        quant_mode=QuantMode.use_weight_only(),
        transa=False,
        transb=False,
        is_qkv=False,
        prefer_managed_weight=True,
    )

Source from the content-addressed store, hash-verified

702class WeightOnlyQuantLinear(Linear):
703
def __init__(
    self,
    in_features,
    out_features,
    bias=True,
    dtype=None,
    tp_group=None,
    tp_size=1,
    tp_rank=0,
    gather_output=True,
    quant_mode=QuantMode.use_weight_only(),
    transa=False,
    transb=False,
    is_qkv=False,
    prefer_managed_weight=True,
):
    """Build a weight-only quantized linear layer (INT8 or INT4 weights).

    Both ``in_features`` and ``out_features`` are rounded up to a multiple of
    ``64 * tp_size`` before the base ``Linear`` is constructed;
    ``self.is_padded`` records whether either dimension was changed. The
    weight created by ``Linear`` is then replaced with an ``int8`` "fake"
    tensor holding the packed quantized values, and a ``per_channel_scale``
    parameter of length ``self.out_features`` is added.

    Args:
        in_features: Number of input features (rounded up if needed).
        out_features: Number of output features (rounded up if needed).
        bias: Forwarded to ``Linear``.
        dtype: Forwarded to ``Linear``; also the dtype of
            ``per_channel_scale``.
        tp_group: Forwarded to ``Linear``.
        tp_size: Forwarded to ``Linear``; also scales the padding multiple.
        tp_rank: Stored as ``self.tp_rank`` (not passed to ``Linear``).
        gather_output: Forwarded to ``Linear``.
        quant_mode: Must select INT8 or INT4 weight-only quantization.
        transa: Stored as ``self.transa``.
        transb: Stored as ``self.transb``.
        is_qkv: Forwarded to ``Linear``.
        prefer_managed_weight: Forwarded to ``Linear``; also used when
            allocating the quantized weight.

    Raises:
        ValueError: If ``quant_mode`` is neither INT8 nor INT4 weight-only.
    """
    # Resolve the quantization format before doing any work. The previous
    # code fell through with ``quant_type_size_in_bits`` unbound and failed
    # with an opaque UnboundLocalError; fail fast with a clear message.
    if quant_mode.is_int8_weight_only():
        weight_only_quant_mode = 1
        quant_type_size_in_bits = 8
    elif quant_mode.is_int4_weight_only():
        weight_only_quant_mode = 2
        quant_type_size_in_bits = 4
    else:
        raise ValueError(
            "WeightOnlyQuantLinear requires an INT8 or INT4 weight-only "
            f"quant_mode, got {quant_mode!r}")

    # Round both feature dimensions up to a multiple of 64 * tp_size.
    multiple = 64 * tp_size
    self.is_padded = False
    if in_features % multiple > 0:
        in_features = math.ceil(in_features / multiple) * multiple
        self.is_padded = True
    if out_features % multiple > 0:
        out_features = math.ceil(out_features / multiple) * multiple
        self.is_padded = True

    super().__init__(in_features,
                     out_features,
                     bias=bias,
                     dtype=dtype,
                     tp_group=tp_group,
                     tp_size=tp_size,
                     gather_output=gather_output,
                     is_qkv=is_qkv,
                     prefer_managed_weight=prefer_managed_weight)
    self.weight_only_quant_mode = weight_only_quant_mode

    # We use a fake tensor with dtype int8 to hold the packed weights. Each
    # int8 element stores 8 / quant_type_size_in_bits quantized values, so
    # the last dimension is out_features * bits / 8. Integer floor division
    # keeps the computation exact (no float round-trip).
    packed_out_features = self.out_features * quant_type_size_in_bits // 8
    self.weight = Parameter(shape=(self.in_features, packed_out_features),
                            dtype="int8",
                            prefer_managed=self.prefer_managed_weight)

    # One scale per output channel.
    scale_shape = (self.out_features, )
    self.per_channel_scale = Parameter(shape=scale_shape, dtype=dtype)

    self.transa = transa
    self.transb = transb
    self.tp_rank = tp_rank
    if self.is_padded:
        # NOTE(review): presumably tells downstream weight handling which
        # dimension to split/pad when the shapes were rounded up -- confirm.
        self.tp_dim = -1
    self.quant_mode = quant_mode
760
761 def forward(self, x, lora_runtime_params=None):

Callers

nothing calls this directly

Calls (5)

Parameter · Class · 0.85
use_weight_only · Method · 0.80
is_int8_weight_only · Method · 0.80
is_int4_weight_only · Method · 0.80
__init__ · Method · 0.45

Tested by

no test coverage detected