MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / __init__

Method __init__

tensorrt_llm/quantization/layers.py:102–143  ·  view source on GitHub ↗
(self,
                 in_features,
                 out_features,
                 bias=True,
                 dtype=None,
                 tp_group=None,
                 tp_size=1,
                 gather_output=True,
                 quant_mode=QuantMode(0),
                 prefer_managed_weight=True)

Source from the content-addressed store, hash-verified

100class SmoothQuantLinear(Linear):
101
def __init__(self,
             in_features,
             out_features,
             bias=True,
             dtype=None,
             tp_group=None,
             tp_size=1,
             gather_output=True,
             quant_mode=QuantMode(0),
             prefer_managed_weight=True):
    """Build a SmoothQuant linear layer with an int8 weight and its scales.

    The base ``Linear`` constructor runs first with the forwarded
    arguments; afterwards ``self.weight`` is replaced by an ``int8``
    parameter and the quantization scale parameters are attached.

    Args:
        in_features: Forwarded to ``Linear.__init__``.
        out_features: Forwarded to ``Linear.__init__``.
        bias: Forwarded to ``Linear.__init__``.
        dtype: Forwarded to ``Linear.__init__``. It does not affect the
            weight created here, which is always ``"int8"``.
        tp_group: Forwarded to ``Linear.__init__``.
        tp_size: Forwarded to ``Linear.__init__``.
        gather_output: Forwarded to ``Linear.__init__``.
        quant_mode: Quantization mode; it must report activation+weight
            quantization via ``has_act_and_weight_quant()``.
            NOTE(review): the default ``QuantMode(0)`` presumably fails
            that check, making this argument effectively required --
            confirm against ``QuantMode``.
        prefer_managed_weight: Forwarded to ``Linear.__init__``.

    Raises:
        ValueError: If ``quant_mode`` does not have activation+weight
            quantization set.
    """
    super().__init__(in_features,
                     out_features,
                     bias=bias,
                     dtype=dtype,
                     tp_group=tp_group,
                     tp_size=tp_size,
                     gather_output=gather_output,
                     prefer_managed_weight=prefer_managed_weight)

    # Everything below relies on act+weight quantization being enabled, so
    # the mode is validated exactly once here instead of being re-tested
    # before each parameter is created.
    if not quant_mode.has_act_and_weight_quant():
        raise ValueError(
            "SmoothQuant Linear has to have act+weight quantization mode set"
        )

    # Replace the weight set up by the base constructor with an int8 one.
    # The shape is read from the attributes on ``self`` rather than from
    # the raw constructor arguments.
    self.weight = Parameter(shape=(self.out_features, self.in_features),
                            dtype="int8",
                            prefer_managed=self.prefer_managed_weight)

    # One scale per output channel when per-channel scaling is requested,
    # otherwise a single scale.
    if quant_mode.has_per_channel_scaling():
        scale_shape = (1, self.out_features)
    else:
        scale_shape = (1, 1)
    self.per_channel_scale = Parameter(shape=scale_shape, dtype="float32")

    # The activation scale is only a stored parameter in static mode.
    if quant_mode.has_act_static_scaling():
        self.act_scale = Parameter(shape=(1, 1), dtype="float32")

    self.quant_mode = quant_mode
144
145 def forward(self, x, lora_runtime_params=None):
146 assert lora_runtime_params is None, "lora is not supported on SmoothQuantLinear now"

Callers

nothing calls this directly

Calls 6

QuantModeClass · 0.85
ParameterClass · 0.85
__init__Method · 0.45

Tested by

no test coverage detected