(param: np.ndarray,
quant_mode: tensorrt_llm.quantization.QuantMode)
| 380 | |
| 381 | |
def quantize(param: np.ndarray,
             quant_mode: tensorrt_llm.quantization.QuantMode):
    """Quantize a weight array to int8 or int4 for weight-only quantization.

    The array is converted to a torch tensor, transposed, made contiguous and
    handed to the ``trtllm`` custom op
    ``symmetric_quantize_last_axis_of_batched_matrix``.

    Args:
        param: Weight array to quantize. It is transposed with ``Tensor.t()``,
            so it is presumably at most 2-D -- TODO confirm with callers.
        quant_mode: Quantization mode; must report either int8 weight-only or
            int4 weight-only.

    Returns:
        A ``(quantized_weights, scales)`` tuple of NumPy arrays.

    Raises:
        ValueError: If ``quant_mode`` is neither int8 nor int4 weight-only.
    """
    # Reject unsupported modes up front, before any tensor work is done.
    wants_int8 = quant_mode.is_int8_weight_only()
    if not wants_int8 and not quant_mode.is_int4_weight_only():
        raise ValueError(f"Invalid configuration got quant_mode={quant_mode}")
    target_dtype = torch.int8 if wants_int8 else torch.quint4x2

    transposed = numpy_to_torch(param).t().contiguous()

    # This op used to be exposed under the torch.ops.fastertransformer
    # namespace; it now lives under torch.ops.trtllm.
    quantize_op = torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix
    packed_weights, scale_tensor = quantize_op(transposed, target_dtype)

    if scale_tensor.dtype == torch.bfloat16:
        # bfloat16 scales are routed through float32 before becoming a NumPy
        # array, then cast back. NOTE(review): the "bfloat16" cast relies on
        # a bfloat16 dtype being registered with NumPy (e.g. via ml_dtypes)
        # -- confirm.
        scale_array = scale_tensor.to(torch.float32).numpy().astype("bfloat16")
    else:
        scale_array = scale_tensor.numpy()
    return packed_weights.numpy(), scale_array
| 406 | |
| 407 | |
# Type alias: a value annotated as `Parsers` may be any one of the four
# parser types listed in the Union below.
Parsers = Union[JAXParser, KerasParser, TorchParser, HfParser]
no test coverage detected