

def quantize_to_fp4_tensor(input: Tensor, sf_scale: Tensor):
    '''
    Quantize an fp16 tensor to FP4 using the QuantizeToFP4 plugin.

    Parameters:
        input : Tensor (On GPU)
            The input tensor. Its shape is [batch_size, seq_len, input_dim], or [num_tokens, input_dim] for remove_input_padding. Must be fp16.
        sf_scale : Tensor (On GPU)
            The global per-tensor scaling factor. Its shape is [1,] and it must be float32.
            It scales the scaling factors (SF) from the input range to the fp8 range: 448.f / (MaxVal of input / 6.f).

    Returns:
        output : Tensor (On GPU)
            The quantized output tensor, in FP4. Its shape is [batch_size, seq_len, input_dim], or [num_tokens, input_dim] for remove_input_padding.
        output_sf : Tensor (On GPU)
            The scaling factor tensor for the quantized output, in FP8. Its shape is [batch_size, seq_len, input_dim / scaling_vector_size], or [num_tokens, input_dim / scaling_vector_size] for remove_input_padding.
    '''
    # Look up the QuantizeToFP4 plugin creator in the TensorRT-LLM namespace.
    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'QuantizeToFP4', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    # The plugin is created with an empty field collection.
    pfc = trt.PluginFieldCollection([])
    quantize_plug = plg_creator.create_plugin("quantize_to_fp4_plugin", pfc)

    plug_inputs = [input.trt_tensor, sf_scale.trt_tensor]
    layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)
    _add_plugin_info(layer, plg_creator, "quantize_to_fp4_plugin", pfc)

    # Output 0 is the FP4 tensor; output 1 holds the FP8 scaling factors.
    quantized = _create_tensor(layer.get_output(0), layer)
    scales = _create_tensor(layer.get_output(1), layer)
    return quantized, scales


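The docstring of quantize_to_fp4_tensor gives the formula for sf_scale and the shape of output_sf, and both can be checked with plain arithmetic. The sketch below is illustrative only and is not part of the library: the helper names global_sf_scale and output_sf_shape are hypothetical, the value 16 for scaling_vector_size is just an example, and both the divisibility check and the all-zeros error are assumptions that the docstring does not state. The constants 448 and 6 are the largest finite values of FP8 (E4M3) and FP4 (E2M1).

```python
FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value (the 448.f in the docstring)
FP4_E2M1_MAX = 6.0    # largest finite FP4 E2M1 value (the 6.f in the docstring)


def global_sf_scale(values):
    """Hypothetical helper: 448 / (max|x| / 6), as written in the docstring."""
    amax = max(abs(v) for v in values)
    if amax == 0.0:
        # Assumption: an all-zero input has no meaningful scale.
        raise ValueError("input is all zeros; the scale is undefined")
    return FP8_E4M3_MAX / (amax / FP4_E2M1_MAX)


def output_sf_shape(input_shape, scaling_vector_size):
    """Hypothetical helper: shape of output_sf for a given input shape."""
    *leading, input_dim = input_shape
    if input_dim % scaling_vector_size != 0:
        # Assumption: input_dim is a multiple of scaling_vector_size.
        raise ValueError("input_dim must be divisible by scaling_vector_size")
    return [*leading, input_dim // scaling_vector_size]


if __name__ == "__main__":
    # max|x| is 3.0, so the scale is 448 / (3 / 6) = 896.0
    print(global_sf_scale([0.5, -3.0, 1.25]))
    # [batch_size, seq_len, input_dim] = [2, 8, 64] gives [2, 8, 4]
    print(output_sf_shape([2, 8, 64], scaling_vector_size=16))
```

In practice the scale would be computed on the host from the tensor's absolute maximum and passed in as a float32 tensor of shape [1,]; how that tensor is constructed is outside what this listing shows.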
def dynamic_quantize(