
def fp4_gemm(input: Tensor,
             input_sf: Tensor,
             weight: Tensor,
             weight_sf: Tensor,
             global_sf: Tensor,
             output_dtype: str | trt.DataType,
             scaling_vector_size: int = 16):
    '''
    Parameters:
        input : Tensor (On GPU)
            The input tensor. Its shape is [batch_size, seq_len, input_dim], or [num_tokens, input_dim] when remove_input_padding is enabled. Should be fp4.
        input_sf : Tensor (On GPU)
            The input scaling factor tensor. Its shape is [batch_size, seq_len, input_dim / scaling_vector_size], or [num_tokens, input_dim / scaling_vector_size] when remove_input_padding is enabled. Should be int32 (4 packed).
        weight : Tensor (On GPU)
            The weight tensor. Its shape is [output_dim, input_dim]. Should be fp4.
        weight_sf : Tensor (On GPU)
            The weight scaling factor tensor. Its shape is [output_dim, input_dim / scaling_vector_size]. Should be fp8.
        global_sf : Tensor (On GPU)
            The global scaling factor tensor. Its shape is [1,]. Should be float32; used as the alpha of the GEMM.
        output_dtype : str | trt.DataType
            The output data type.
        scaling_vector_size : int
            The scaling vector block size.
    '''
    # Accept the dtype either as a string name or as a trt.DataType.
    if isinstance(output_dtype, str):
        output_dtype = str_dtype_to_trt(output_dtype)

    # Look up the Fp4Gemm plugin creator in the TensorRT plugin registry.
    fp4_gemm_plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'Fp4Gemm', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert fp4_gemm_plg_creator is not None
    sv_vec_size = trt.PluginField("sv_vec_size",
                                  np.array(scaling_vector_size, dtype=np.int32),
                                  trt.PluginFieldType.INT32)
    # Note: output_dtype is rebound here to the PluginField that wraps it.
    output_dtype = trt.PluginField("output_type_id",
                                   np.array([int(output_dtype)], np.int32),
                                   trt.PluginFieldType.INT32)
    pfc = trt.PluginFieldCollection([sv_vec_size, output_dtype])
    fp4_gemm_plug = fp4_gemm_plg_creator.create_plugin("fp4_gemm", pfc)
    # The plugin consumes the underlying TensorRT tensors, in this order.
    plug_inputs = [input, input_sf, weight, weight_sf, global_sf]
    plug_inputs = [i.trt_tensor for i in plug_inputs]
    layer = default_trtnet().add_plugin_v2(plug_inputs, fp4_gemm_plug)
    _add_plugin_info(layer, fp4_gemm_plg_creator, "fp4_gemm", pfc)
    output = _create_tensor(layer.get_output(0), layer)
    return output


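The shape contract in the `fp4_gemm` docstring can be sketched in plain NumPy. This is an illustration only, not the plugin's implementation. The helper names `expected_fp4_gemm_shapes` and `reference_block_scaled_gemm` are hypothetical, and several points are assumptions rather than facts from the source: that `input_dim` must be a multiple of `scaling_vector_size`, that the output shape is the input's leading dimensions followed by `output_dim`, and that each scaling factor multiplies one block of `scaling_vector_size` consecutive elements before `global_sf` is applied as alpha. Ordinary floats stand in for the fp4 and fp8 values, and the packed int32 storage of `input_sf` is ignored.

```python
import numpy as np


def expected_fp4_gemm_shapes(input_shape, output_dim, scaling_vector_size=16):
    """Logical shapes implied by the docstring (hypothetical helper).

    The "output" entry and the divisibility check are assumptions.
    """
    *lead, input_dim = input_shape
    if input_dim % scaling_vector_size != 0:
        raise ValueError("input_dim must be a multiple of scaling_vector_size")
    blocks = input_dim // scaling_vector_size
    return {
        "input_sf": (*lead, blocks),
        "weight": (output_dim, input_dim),
        "weight_sf": (output_dim, blocks),
        "output": (*lead, output_dim),
    }


def reference_block_scaled_gemm(x, x_sf, w, w_sf, global_sf,
                                scaling_vector_size=16):
    """Float emulation of a block-scaled GEMM (assumed semantics).

    Each scaling factor is broadcast over one block of
    scaling_vector_size elements along input_dim, then global_sf
    scales the matrix product as alpha.
    """
    x_scale = np.repeat(x_sf, scaling_vector_size, axis=-1)
    w_scale = np.repeat(w_sf, scaling_vector_size, axis=-1)
    return float(global_sf[0]) * ((x * x_scale) @ (w * w_scale).T)


# Small usage example with num_tokens=2, input_dim=32, output_dim=3.
x = np.ones((2, 32), dtype=np.float32)
x_sf = np.full((2, 2), 2.0, dtype=np.float32)
w = np.ones((3, 32), dtype=np.float32)
w_sf = np.full((3, 2), 0.5, dtype=np.float32)
global_sf = np.array([0.25], dtype=np.float32)
out = reference_block_scaled_gemm(x, x_sf, w, w_sf, global_sf)
```

A check like `expected_fp4_gemm_shapes` could be run on the host before building the network to catch mismatched scaling factor shapes early. The float emulation is only useful as a mental model of how the three scaling inputs might combine.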
def quantize_to_fp4_tensor(input: Tensor, sf_scale: Tensor):