(x: Tensor,
clamp_val: Optional[Tensor] = None)
| 859 | |
| 860 | |
def quantize_fp8_per_token(
        x: Tensor,
        clamp_val: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
    """Quantize ``x`` to FP8 with one scale per row of the last dimension.

    The implementation is selected by
    ``default_net().plugin_config.quantize_per_token_plugin``:

    * Plugin disabled: the quantization is built from elementwise ops. The
      absolute maximum along the last dimension is mapped to 448.0 (the
      largest finite FP8 E4M3 magnitude). The result is clipped to
      [-448, 448] and cast to ``'fp8'``.
    * Plugin enabled: a ``QuantizePerToken`` TensorRT-LLM plugin layer is
      added. When ``clamp_val`` is given, it is passed as a second plugin
      input and the plugin's ``clamp_enabled`` field is set.

    Args:
        x: Tensor to quantize. The reduction for the scale is over the last
            dimension (``-1``).
        clamp_val: Optional tensor forwarded to the plugin as its second
            input. NOTE(review): it is not used by the non-plugin path;
            confirm whether that fallback should honour it too.

    Returns:
        ``(quantized, scales)``: the FP8 tensor and the per-row scales. On
        the non-plugin path the scales are ``max(|x|, -1) / 448.0`` with
        the reduced dimension kept (size 1).
    """
    if not default_net().plugin_config.quantize_per_token_plugin:
        x = cast(x, 'float32')
        # Per-row absolute maximum, keepdim so it broadcasts against x.
        xmax = x.abs().max(-1, keepdim=True)
        scale = xmax / 448.0
        # NOTE(review): a row that is entirely zero gives xmax == 0, so the
        # division below produces NaN; confirm callers never pass such rows.
        out = x * 448.0 / xmax
        # No integer rounding here (unlike INT8 quantization). FP8 can
        # represent fractional values, so rounding to integers first would
        # lose precision for |out| < 8 and double-round larger values. The
        # fp8 cast below performs the rounding itself.
        out = clip(out, -448, 448)
        quantized_out = cast(out, 'fp8')
        return quantized_out, scale

    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'QuantizePerToken', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    # Plugin fields: FP8 output type, row-wise FP8 quant mode, whether a
    # clamp input is supplied, and no per-token sum output.
    output_type = trt.PluginField("type_id",
                                  np.array([int(trt.fp8)], np.int32),
                                  trt.PluginFieldType.INT32)
    quant_mode = trt.PluginField(
        "quant_mode",
        np.array([int(QuantMode.from_description(use_fp8_rowwise=True))],
                 np.int32), trt.PluginFieldType.INT32)
    clamp_enabled = trt.PluginField(
        "clamp_enabled", np.array([clamp_val is not None], np.int8),
        trt.PluginFieldType.INT8)
    sum_per_token_pf = trt.PluginField("sum_per_token",
                                       np.array([int(False)], np.int32),
                                       trt.PluginFieldType.INT32)
    pfc = trt.PluginFieldCollection(
        [output_type, quant_mode, clamp_enabled, sum_per_token_pf])
    quantize_plug = plg_creator.create_plugin("quantize_per_token_plugin",
                                              pfc)

    plug_inputs = [x.trt_tensor]
    # Explicit None check so the input list always matches the
    # "clamp_enabled" field above; Tensor truthiness is not a presence test.
    if clamp_val is not None:
        plug_inputs.append(clamp_val.trt_tensor)
    layer = default_trtnet().add_plugin_v2(plug_inputs, quantize_plug)
    # Weakly-typed networks get an explicit dynamic range on the FP8 output.
    if not default_net().strongly_typed:
        layer.get_output(0).set_dynamic_range(-448, 448)
    _add_plugin_info(layer, plg_creator, "quantize_per_token_plugin", pfc)

    quantized = _create_tensor(layer.get_output(0), layer)
    scales = _create_tensor(layer.get_output(1), layer)

    return quantized, scales
| 907 | |
| 908 | |
| 909 | def quantize_tensor(x, scale): |
no test coverage detected