MCPcopy Index your code
hub / github.com/NVIDIA/TensorRT-LLM / __init__

Method __init__

tensorrt_llm/quantization/layers.py:300–343  ·  view source on GitHub ↗
(
            self,
            normalized_shape,
            eps=1e-06,
            elementwise_affine=True,
            dtype=None,
            quant_mode=QuantMode(0),
            bias=False,
            clamp_val=None,
    )

Source from the content-addressed store, hash-verified

298class SmoothQuantRmsNorm(Module):
299
def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_affine=True,
        dtype=None,
        quant_mode=QuantMode(0),
        bias=False,
        clamp_val=None,
):
    """Set up the parameters of a SmoothQuant RMS-norm layer.

    Args:
        normalized_shape: An int or an iterable of ints. A bare int is
            wrapped into a one-element tuple; the result is stored as a
            tuple and used as the shape of ``weight`` and ``bias``.
        eps: Stored unchanged on ``self.eps``.
        elementwise_affine: When true, a ``weight`` Parameter of shape
            ``normalized_shape`` is created; otherwise ``weight`` is
            registered as ``None``.
        dtype: dtype passed to the ``weight``, ``bias`` and
            ``scale_to_int`` Parameters.
        quant_mode: Must report ``has_act_and_weight_quant()``. Note that
            the default ``QuantMode(0)`` is presumably rejected by that
            check, so callers are expected to pass a mode explicitly.
        bias: When true, a ``bias`` Parameter of shape
            ``normalized_shape`` is created; otherwise ``bias`` is
            registered as ``None``.
        clamp_val: Optional two-element list or tuple. When truthy it is
            stored as a float32 buffer Parameter; when falsy (``None`` or
            an empty sequence) ``clamp_val`` is registered as ``None``.

    Raises:
        ValueError: If ``quant_mode`` does not enable activation-and-weight
            quantization, or if a truthy ``clamp_val`` is not a
            two-element list/tuple.
    """
    super().__init__()
    if isinstance(normalized_shape, int):
        normalized_shape = (normalized_shape, )
    if not quant_mode.has_act_and_weight_quant():
        raise ValueError(
            "SmoothQuant Rms norm has to have some quantization mode set")

    self.normalized_shape = tuple(normalized_shape)
    self.elementwise_affine = elementwise_affine
    if self.elementwise_affine:
        self.weight = Parameter(shape=self.normalized_shape, dtype=dtype)
    else:
        self.register_parameter('weight', None)

    if bias:
        self.bias = Parameter(shape=self.normalized_shape, dtype=dtype)
    else:
        self.register_parameter('bias', None)

    if clamp_val:
        # Tuples are accepted alongside lists: np.array treats both the
        # same, so rejecting a 2-tuple was needlessly strict.
        if not (isinstance(clamp_val, (list, tuple))
                and len(clamp_val) == 2):
            raise ValueError(f'unsupported clamp_val {clamp_val}')
        self.clamp_val = Parameter(np.array(clamp_val, dtype=np.float32),
                                   dtype='float32',
                                   is_buffer=True)
    else:
        self.register_parameter('clamp_val', None)

    self.eps = eps
    self.dtype = dtype
    self.quant_mode = quant_mode

    # The guard above already rejected every quant_mode without
    # activation-and-weight quantization, so this parameter is always
    # required; the former `else: register_parameter(..., None)` branch
    # could never execute.
    self.scale_to_int = Parameter(shape=(1, ), dtype=dtype)
344
345 def forward(self, x):
346 weight = None if self.weight is None else self.weight.value

Callers

nothing calls this directly

Calls 5

QuantModeClass · 0.85
ParameterClass · 0.85
__init__Method · 0.45
register_parameterMethod · 0.45

Tested by

no test coverage detected