| 350 | |
| 351 | # TODO: Should be renamed to layer_norm_quantize. |
| 352 | def smooth_quant_layer_norm(input: Tensor, |
| 353 | normalized_shape: Union[int, Tuple[int]], |
| 354 | weight: Optional[Tensor] = None, |
| 355 | bias: Optional[Tensor] = None, |
| 356 | scale: Optional[Tensor] = None, |
| 357 | eps: float = 1e-05, |
| 358 | use_diff_of_squares: bool = True, |
| 359 | dynamic_act_scaling: bool = False) -> Tensor: |
| 360 | if not default_net().plugin_config.layernorm_quantization_plugin: |
| 361 | dtype = trt_dtype_to_np(input.dtype) |
| 362 | if weight is None: |
| 363 | weight = constant(np.ones(normalized_shape, dtype=dtype)) |
| 364 | if bias is None: |
| 365 | bias = constant(np.zeros(normalized_shape, dtype=dtype)) |
| 366 | result = layer_norm(input, normalized_shape, weight, bias, eps, |
| 367 | use_diff_of_squares) |
| 368 | if not dynamic_act_scaling: |
| 369 | return quantize_tensor(result, scale) |
| 370 | else: |
| 371 | return quantize_per_token(result) |
| 372 | else: |
| 373 | plg_creator = trt.get_plugin_registry().get_plugin_creator( |
| 374 | 'LayernormQuantization', '1', TRT_LLM_PLUGIN_NAMESPACE) |
| 375 | assert plg_creator is not None |
| 376 | |
| 377 | output_type = trt.PluginField("out_type_id", |
| 378 | np.array([int(trt.int8)], np.int32), |
| 379 | trt.PluginFieldType.INT32) |
| 380 | quant_mode = trt.PluginField( |
| 381 | "quant_mode", |
| 382 | np.array([int(QuantMode.use_smooth_quant(per_token=True))], |
| 383 | np.int32), trt.PluginFieldType.INT32) |
| 384 | eps = trt.PluginField("eps", np.array(eps, dtype=np.float32), |
| 385 | trt.PluginFieldType.FLOAT32) |
| 386 | use_diff_of_squares = trt.PluginField( |
| 387 | "use_diff_of_squares", |
| 388 | np.array([int(use_diff_of_squares)], dtype=np.int32), |
| 389 | trt.PluginFieldType.INT32) |
| 390 | |
| 391 | dyn_act_scaling = trt.PluginField( |
| 392 | "dyn_act_scaling", np.array([int(dynamic_act_scaling)], np.int32), |
| 393 | trt.PluginFieldType.INT32) |
| 394 | |
| 395 | p_dtype = default_net().plugin_config.layernorm_quantization_plugin |
| 396 | pf_type = trt.PluginField( |
| 397 | "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32), |
| 398 | trt.PluginFieldType.INT32) |
| 399 | pfc = trt.PluginFieldCollection([ |
| 400 | eps, use_diff_of_squares, dyn_act_scaling, pf_type, output_type, |
| 401 | quant_mode |
| 402 | ]) |
| 403 | layernorm_plug = plg_creator.create_plugin("layernorm_quantized", pfc) |
| 404 | normalized_shape = [normalized_shape] if isinstance( |
| 405 | normalized_shape, int) else normalized_shape |
| 406 | if weight is None: |
| 407 | weight = constant( |
| 408 | np.ones(normalized_shape, dtype=str_dtype_to_np(p_dtype))) |
| 409 | if bias is None: |