Function quantize_to_fp4_tensor

tensorrt_llm/quantization/functional.py:1367–1393 (github.com/NVIDIA/TensorRT-LLM)


(input: Tensor, sf_scale: Tensor)


```python
# Excerpt: trt, Tensor, default_trtnet, _add_plugin_info, _create_tensor and
# TRT_LLM_PLUGIN_NAMESPACE are defined or imported elsewhere in functional.py.
def quantize_to_fp4_tensor(input: Tensor, sf_scale: Tensor):
    '''
    Parameters:
        input : Tensor (On GPU)
            The input tensor. Its shape is [batch_size, seq_len, input_dim], or
            [num_tokens, input_dim] with remove_input_padding. Should be fp16.
        sf_scale : Tensor (On GPU)
            The global per-tensor scaling factor. Its shape is [1,]. Should be float32.
            Used to scale the SF from the input range to the FP8 range
            (448.f / (MaxVal of input / 6.f)).

    Returns:
        output : Tensor (On GPU)
            The quantized tensor. Its shape is [batch_size, seq_len, input_dim], or
            [num_tokens, input_dim] with remove_input_padding. Should be FP4.
        output_sf : Tensor (On GPU)
            The block scaling factors of the quantized input. Its shape is
            [batch_size, seq_len, input_dim / scaling_vector_size], or
            [num_tokens, input_dim / scaling_vector_size] with remove_input_padding.
            Should be FP8.
    '''
    # Look up the QuantizeToFP4 plugin (version '1') in the TensorRT-LLM namespace.
    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'QuantizeToFP4', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    # The plugin is created with an empty field collection (no attributes).
    pfc = trt.PluginFieldCollection([])
    quantize_plug = plg_creator.create_plugin("quantize_to_fp4_plugin", pfc)

    # Plugin inputs: the activation tensor and the global scale.
    plug_inputs = [input.trt_tensor, sf_scale.trt_tensor]
    layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)
    _add_plugin_info(layer, plg_creator, "quantize_to_fp4_plugin", pfc)

    # Output 0 is the FP4 tensor; output 1 holds the per-block scaling factors.
    quantized = _create_tensor(layer.get_output(0), layer)
    scales = _create_tensor(layer.get_output(1), layer)
    return quantized, scales
```
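The docstring gives the global scale as sf_scale = 448 / (max|input| / 6). For example, max|input| = 12 gives sf_scale = 448 / 2 = 224, so the block that holds the tensor maximum gets a scale of exactly 448, the FP8 E4M3 maximum. The NumPy sketch below is a reference emulation of that two-level scaling, not the plugin's CUDA kernel. It assumes scaling_vector_size is 16 (as in NVFP4), rounds to the nearest E2M1 value with ties toward the smaller magnitude, keeps block scales in float32 instead of rounding them to FP8, and does not model how the plugin packs FP4 values or lays out output_sf in memory. The helper names global_sf_scale, round_to_e2m1, emulate_fp4_quantize, and dequantize are hypothetical.

```python
import numpy as np

# Representable FP4 (E2M1) magnitudes; the sign is applied separately.
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)
FP4_MAX = 6.0    # largest E2M1 magnitude
FP8_MAX = 448.0  # largest FP8 E4M3 magnitude
BLOCK = 16       # assumed value of scaling_vector_size (as in NVFP4)


def global_sf_scale(x: np.ndarray) -> np.float32:
    """Docstring formula for sf_scale: 448 / (max|x| / 6)."""
    amax = float(np.abs(x).max())
    return np.float32(FP8_MAX / (amax / FP4_MAX))


def round_to_e2m1(v: np.ndarray) -> np.ndarray:
    """Round to the nearest E2M1 value, saturating at +/-6.

    Ties resolve to the smaller magnitude (argmin keeps the first match),
    a simplification of the hardware rounding mode.
    """
    mag = np.minimum(np.abs(v), FP4_MAX)
    idx = np.abs(mag[..., None] - E2M1_GRID).argmin(axis=-1)
    return np.sign(v) * E2M1_GRID[idx]


def emulate_fp4_quantize(x: np.ndarray, sf_scale: np.float32, block: int = BLOCK):
    """Return (FP4 values stored as float32, per-block scales in the FP8 range)."""
    *lead, dim = x.shape
    assert dim % block == 0, "input_dim must be a multiple of the block size"
    blocks = x.astype(np.float32).reshape(*lead, dim // block, block)
    block_amax = np.abs(blocks).max(axis=-1)
    # Per-block scale mapped into the FP8 range; not rounded to E4M3 here.
    block_sf = sf_scale * (block_amax / FP4_MAX)
    # Multiplier that maps each block's max magnitude onto 6; all-zero blocks stay 0.
    safe_sf = np.where(block_sf == 0, 1.0, block_sf)
    q = round_to_e2m1(blocks * (sf_scale / safe_sf)[..., None])
    return q.reshape(x.shape), block_sf


def dequantize(q: np.ndarray, block_sf: np.ndarray, sf_scale: np.float32,
               block: int = BLOCK) -> np.ndarray:
    """Invert the scaling: value = q * block_sf / sf_scale."""
    *lead, dim = q.shape
    blocks = q.reshape(*lead, dim // block, block)
    return (blocks * (block_sf / sf_scale)[..., None]).reshape(q.shape)


# Usage: a [batch_size, seq_len, input_dim] = [2, 3, 64] activation.
x = np.random.default_rng(0).standard_normal((2, 3, 64)).astype(np.float32)
s = global_sf_scale(x)
q, sf = emulate_fp4_quantize(x, s)
print(q.shape, sf.shape)  # sf has input_dim / 16 = 4 entries per token
```

Keeping block scales in float32 isolates the two-level scaling idea. A closer emulation would round block_sf to E4M3 before deriving the multiplier, so that quantization and dequantization use the same stored scale.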

Callers (2)

forward · method · 0.85
forward · method · 0.85

Calls (5)

default_trtnet · function · 0.85
_add_plugin_info · function · 0.85
_create_tensor · function · 0.85
create_plugin · method · 0.80
get_output · method · 0.45

Tested by

No test coverage detected.