hub / github.com/NVIDIA/TensorRT-LLM / fp4_gemm

Function fp4_gemm

tensorrt_llm/quantization/functional.py:1321–1364


fp4_gemm(input: Tensor,
         input_sf: Tensor,
         weight: Tensor,
         weight_sf: Tensor,
         global_sf: Tensor,
         output_dtype: str | trt.DataType,
         scaling_vector_size: int = 16)
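The operand shapes in the docstring all follow from the input shape, `output_dim`, and `scaling_vector_size`. The helper below is a minimal sketch that only restates those documented shapes. `expected_fp4_gemm_shapes` is a hypothetical name, not part of TensorRT-LLM, and the divisibility check is an assumption inferred from the `input_dim / scaling_vector_size` expression. It does not model how the scaling factors are packed in memory.

```python
from typing import Dict, List, Sequence


def expected_fp4_gemm_shapes(input_shape: Sequence[int],
                             output_dim: int,
                             scaling_vector_size: int = 16) -> Dict[str, List[int]]:
    """Restate the operand shapes listed in the fp4_gemm docstring.

    Hypothetical helper for illustration; not a TensorRT-LLM API.
    """
    *leading, input_dim = input_shape
    # Assumption: one scaling factor per block, so the blocks must tile input_dim.
    if input_dim % scaling_vector_size != 0:
        raise ValueError("input_dim must be divisible by scaling_vector_size")
    sf_cols = input_dim // scaling_vector_size
    return {
        "input": list(input_shape),
        "input_sf": [*leading, sf_cols],
        "weight": [output_dim, input_dim],
        "weight_sf": [output_dim, sf_cols],
        "global_sf": [1],
    }


# Padded layout: [batch_size, seq_len, input_dim]
padded = expected_fp4_gemm_shapes([8, 128, 4096], output_dim=11008)
# Packed layout (remove_input_padding): [num_tokens, input_dim]
packed = expected_fp4_gemm_shapes([1024, 4096], output_dim=11008)
```

With the default block size of 16, an `input_dim` of 4096 gives 256 scaling factors per row for both the input and the weight.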

Source from the content-addressed store, hash-verified

def fp4_gemm(input: Tensor,
             input_sf: Tensor,
             weight: Tensor,
             weight_sf: Tensor,
             global_sf: Tensor,
             output_dtype: str | trt.DataType,
             scaling_vector_size: int = 16):
    '''
    Parameters:
        input : Tensor (On GPU)
            The input tensor. Its shape is [batch_size, seq_len, input_dim] or [num_tokens, input_dim] for remove_input_padding, should be fp4
        input_sf : Tensor (On GPU)
            The input scaling factor tensor. Its shape is [batch_size, seq_len, input_dim / scaling_vector_size] or [num_tokens, input_dim / scaling_vector_size] for remove_input_padding, should be int32 (4 packed)
        weight : Tensor (On GPU)
            The weight tensor. Its shape is [output_dim, input_dim], should be fp4
        weight_sf : Tensor (On GPU)
            The weight scaling factor tensor. Its shape is [output_dim, input_dim / scaling_vector_size], should be fp8
        global_sf : Tensor (On GPU)
            The global scaling factor tensor. Its shape is [1,], should be float32, used as alpha of Gemm.
        output_dtype: str
            output data type
        scaling_vector_size: int
            scaling vector block size
    '''
    if isinstance(output_dtype, str):
        output_dtype = str_dtype_to_trt(output_dtype)

    fp4_gemm_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'Fp4Gemm', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert fp4_gemm_plg_creator is not None
    sv_vec_size = trt.PluginField("sv_vec_size",
                                  np.array(scaling_vector_size, dtype=np.int32),
                                  trt.PluginFieldType.INT32)
    output_dtype = trt.PluginField("output_type_id",
                                   np.array([int(output_dtype)], np.int32),
                                   trt.PluginFieldType.INT32)
    pfc = trt.PluginFieldCollection([sv_vec_size, output_dtype])
    fp4_gemm_plug = fp4_gemm_plg_creator.create_plugin("fp4_gemm", pfc)
    plug_inputs = [input, input_sf, weight, weight_sf, global_sf]
    plug_inputs = [i.trt_tensor for i in plug_inputs]
    layer = default_trtnet().add_plugin_v2(plug_inputs, fp4_gemm_plug)
    _add_plugin_info(layer, fp4_gemm_plg_creator, "fp4_gemm", pfc)
    output = _create_tensor(layer.get_output(0), layer)
    return output
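The source above only wires up the `Fp4Gemm` plugin; the arithmetic happens inside the plugin and is not shown here. As a rough mental model, the following NumPy sketch emulates a block-scaled GEMM in ordinary floating point. It assumes each scaling factor multiplies a contiguous run of `block` elements along `input_dim` and that `global_sf` is applied as a scalar `alpha` on the product, as the docstring's "used as alpha of Gemm" suggests. `block_scaled_gemm_reference` is a hypothetical name, and the sketch does not model FP4 encoding, packed int32 or fp8 scaling factors, or any layout the real kernel may require.

```python
import numpy as np


def block_scaled_gemm_reference(x, x_sf, w, w_sf, alpha, block=16):
    """Float emulation of alpha * (x * x_sf) @ (w * w_sf).T with per-block scales.

    Illustrative assumption about the semantics; not the plugin's implementation.
    x:    [num_tokens, input_dim]          w:    [output_dim, input_dim]
    x_sf: [num_tokens, input_dim // block] w_sf: [output_dim, input_dim // block]
    """
    # Expand each per-block scale across the `block` elements it covers.
    x_deq = x * np.repeat(x_sf, block, axis=-1)
    w_deq = w * np.repeat(w_sf, block, axis=-1)
    # weight is stored as [output_dim, input_dim], so transpose for the product.
    return alpha * (x_deq @ w_deq.T)


x = np.ones((2, 32))
x_sf = np.array([[1.0, 2.0], [3.0, 4.0]])
w = np.ones((3, 32))
w_sf = np.full((3, 2), 0.5)
out = block_scaled_gemm_reference(x, x_sf, w, w_sf, alpha=2.0)
```

In this toy case `out` has shape `(2, 3)`: one row per token and one column per output feature.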

Callers 2

forward · Method · 0.85
forward · Method · 0.85

Calls 6

str_dtype_to_trt · Function · 0.85
default_trtnet · Function · 0.85
_add_plugin_info · Function · 0.85
_create_tensor · Function · 0.85
create_plugin · Method · 0.80
get_output · Method · 0.45

Tested by

no test coverage detected