| 32 | |
| 33 | |
| 34 | def smooth_quant_gemm(input: Tensor, weights: Tensor, scales_a: Tensor, |
| 35 | scales_b: Tensor, per_token_scaling: bool, |
| 36 | per_channel_scaling: bool, dtype: str) -> Tensor: |
| 37 | if not default_net().plugin_config.smooth_quant_gemm_plugin: |
| 38 | if per_token_scaling and input.size(0) == -1: |
| 39 | # WAR for DQ per-token scaling doesn't support dynamic shapes |
| 40 | |
| 41 | scale_one = constant(np.array(1.0, dtype=np.float32)) |
| 42 | input = dequantize(input, scale_one, 0, 'float32') |
| 43 | weights = dequantize(weights, scale_one, 0, 'float32') |
| 44 | result = matmul(input, weights, False, True, False) |
| 45 | scales = matmul(scales_a, scales_b, False, False, False) |
| 46 | result = result * scales |
| 47 | result = cast(result, dtype) |
| 48 | return result |
| 49 | else: |
| 50 | if not per_token_scaling: |
| 51 | scales_a = view(scales_a, []) |
| 52 | else: |
| 53 | scales_a = flatten(scales_a) |
| 54 | if not per_channel_scaling: |
| 55 | scales_b = view(scales_b, []) |
| 56 | else: |
| 57 | scales_b = flatten(scales_b) |
| 58 | input = dequantize(input, scales_a, 0, dtype) |
| 59 | weights = dequantize(weights, scales_b, 0, dtype) |
| 60 | result = matmul(input, weights, False, True, False) |
| 61 | return result |
| 62 | else: |
| 63 | plg_creator = trt.get_plugin_registry().get_plugin_creator( |
| 64 | 'SmoothQuantGemm', '1', TRT_LLM_PLUGIN_NAMESPACE) |
| 65 | assert plg_creator is not None |
| 66 | |
| 67 | per_channel_scaling = 1 if per_channel_scaling else 0 |
| 68 | per_channel_scaling = trt.PluginField( |
| 69 | "has_per_channel_scaling", |
| 70 | np.array(per_channel_scaling, dtype=np.int32), |
| 71 | trt.PluginFieldType.INT32) |
| 72 | |
| 73 | per_token_scaling = 1 if per_token_scaling else 0 |
| 74 | per_token_scaling = trt.PluginField( |
| 75 | "has_per_token_scaling", np.array(per_token_scaling, |
| 76 | dtype=np.int32), |
| 77 | trt.PluginFieldType.INT32) |
| 78 | |
| 79 | p_dtype = default_net().plugin_config.smooth_quant_gemm_plugin |
| 80 | pf_type = trt.PluginField( |
| 81 | "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32), |
| 82 | trt.PluginFieldType.INT32) |
| 83 | |
| 84 | pfc = trt.PluginFieldCollection( |
| 85 | [per_channel_scaling, per_token_scaling, pf_type]) |
| 86 | gemm_plug = plg_creator.create_plugin("sq_gemm", pfc) |
| 87 | plug_inputs = [ |
| 88 | input.trt_tensor, weights.trt_tensor, scales_a.trt_tensor, |
| 89 | scales_b.trt_tensor |
| 90 | ] |
| 91 | layer = default_trtnet().add_plugin_v2(plug_inputs, gemm_plug) |