hub / github.com/NVIDIA/TensorRT-LLM / quantize

Function quantize

tensorrt_llm/models/gemma/convert.py:382–405
quantize(param: np.ndarray, quant_mode: tensorrt_llm.quantization.QuantMode)


def quantize(param: np.ndarray,
             quant_mode: tensorrt_llm.quantization.QuantMode):
    if quant_mode.is_int8_weight_only():
        quant_dtype = torch.int8
    elif quant_mode.is_int4_weight_only():
        quant_dtype = torch.quint4x2
    else:
        raise ValueError(f"Invalid configuration got quant_mode={quant_mode}")

    param = numpy_to_torch(param)
    param = param.t().contiguous()

    # previously this fn was available in torch.ops.fastertransformer namespace
    (
        quantized_weights,
        scales,
    ) = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
        param, quant_dtype)

    if scales.dtype == torch.bfloat16:
        scales = scales.to(torch.float32).numpy().astype("bfloat16")
    else:
        scales = scales.numpy()
    return quantized_weights.numpy(), scales


Parsers = Union[JAXParser, KerasParser, TorchParser, HfParser]
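
The function above hands the actual quantization to the `torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix` custom op, whose internals are not shown on this page. As a rough illustration of the general idea only, the following is a minimal NumPy sketch of symmetric int8 quantization with one scale per column. The names `symmetric_quantize_int8` and `dequantize` are hypothetical, and the sketch assumes a plain max-abs scaling rule; it does not reproduce int4 packing or any weight layout changes the real op may perform.

```python
import numpy as np


def symmetric_quantize_int8(weight: np.ndarray):
    """Hypothetical reference: symmetric int8 quantization, one scale per column.

    This is an illustrative sketch, not the TensorRT-LLM kernel.
    """
    w = weight.astype(np.float32)
    # Largest magnitude in each column decides that column's scale.
    max_abs = np.abs(w).max(axis=0)
    # Map the largest magnitude to 127; use 1.0 for all-zero columns
    # to avoid dividing by zero.
    scales = np.where(max_abs > 0, max_abs / 127.0, 1.0).astype(np.float32)
    # Round to the nearest integer and clamp to the symmetric int8 range.
    q = np.clip(np.rint(w / scales), -127, 127).astype(np.int8)
    return q, scales


def dequantize(q: np.ndarray, scales: np.ndarray) -> np.ndarray:
    """Recover an approximation of the original weights."""
    return q.astype(np.float32) * scales
```

Under these assumptions, each dequantized value differs from the original by at most half of its column's scale, which is why a single scale per column is enough for a symmetric scheme (no zero point is stored).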

Callers 1

load_gemma_weights · Function · 0.70

Calls 4

numpy_to_torch · Function · 0.90
is_int8_weight_only · Method · 0.80
is_int4_weight_only · Method · 0.80
to · Method · 0.45

Tested by

no test coverage detected