hub / github.com/NVIDIA/TensorRT-LLM / smooth_quant_gemm

Function smooth_quant_gemm

tensorrt_llm/quantization/functional.py:34–96
(input: Tensor, weights: Tensor, scales_a: Tensor,
                      scales_b: Tensor, per_token_scaling: bool,
                      per_channel_scaling: bool, dtype: str)


def smooth_quant_gemm(input: Tensor, weights: Tensor, scales_a: Tensor,
                      scales_b: Tensor, per_token_scaling: bool,
                      per_channel_scaling: bool, dtype: str) -> Tensor:
    if not default_net().plugin_config.smooth_quant_gemm_plugin:
        if per_token_scaling and input.size(0) == -1:
            # WAR for DQ per-token scaling doesn't support dynamic shapes
            scale_one = constant(np.array(1.0, dtype=np.float32))
            input = dequantize(input, scale_one, 0, 'float32')
            weights = dequantize(weights, scale_one, 0, 'float32')
            result = matmul(input, weights, False, True, False)
            scales = matmul(scales_a, scales_b, False, False, False)
            result = result * scales
            result = cast(result, dtype)
            return result
        else:
            if not per_token_scaling:
                scales_a = view(scales_a, [])
            else:
                scales_a = flatten(scales_a)
            if not per_channel_scaling:
                scales_b = view(scales_b, [])
            else:
                scales_b = flatten(scales_b)
            input = dequantize(input, scales_a, 0, dtype)
            weights = dequantize(weights, scales_b, 0, dtype)
            result = matmul(input, weights, False, True, False)
            return result
    else:
        plg_creator = trt.get_plugin_registry().get_plugin_creator(
            'SmoothQuantGemm', '1', TRT_LLM_PLUGIN_NAMESPACE)
        assert plg_creator is not None

        per_channel_scaling = 1 if per_channel_scaling else 0
        per_channel_scaling = trt.PluginField(
            "has_per_channel_scaling",
            np.array(per_channel_scaling, dtype=np.int32),
            trt.PluginFieldType.INT32)

        per_token_scaling = 1 if per_token_scaling else 0
        per_token_scaling = trt.PluginField(
            "has_per_token_scaling", np.array(per_token_scaling,
                                              dtype=np.int32),
            trt.PluginFieldType.INT32)

        p_dtype = default_net().plugin_config.smooth_quant_gemm_plugin
        pf_type = trt.PluginField(
            "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
            trt.PluginFieldType.INT32)

        pfc = trt.PluginFieldCollection(
            [per_channel_scaling, per_token_scaling, pf_type])
        gemm_plug = plg_creator.create_plugin("sq_gemm", pfc)
        plug_inputs = [
            input.trt_tensor, weights.trt_tensor, scales_a.trt_tensor,
            scales_b.trt_tensor
        ]
        layer = default_trtnet().add_plugin_v2(plug_inputs, gemm_plug)
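
The two non-plugin branches above compute the same product in different orders: one dequantizes both operands and then multiplies, the other multiplies the raw quantized values and rescales the result. The NumPy sketch below illustrates why these should agree numerically. It is not TensorRT-LLM code. The helper names `sq_gemm_reference` and `sq_gemm_scale_after` are hypothetical, and the shapes are assumptions inferred from the `matmul` calls: `input` is `[M, K]`, `weights` is `[N, K]`, `scales_a` is `[M, 1]` and `scales_b` is `[1, N]`.

```python
import numpy as np


def sq_gemm_reference(a_q, w_q, scales_a, scales_b,
                      per_token_scaling, per_channel_scaling):
    """Hypothetical NumPy analogue of the dequantize-then-matmul branch."""
    # Per-tensor scales collapse to a scalar; per-token / per-channel scales
    # become 1-D, loosely mirroring view(..., []) and flatten(...).
    s_a = scales_a.reshape(-1) if per_token_scaling else scales_a.reshape(())
    s_b = scales_b.reshape(-1) if per_channel_scaling else scales_b.reshape(())

    # Assumed axis-0 semantics: one scale per activation row (token) and
    # one scale per weight row (output channel).
    a = a_q.astype(np.float32) * (s_a[:, None] if per_token_scaling else s_a)
    w = w_q.astype(np.float32) * (s_b[:, None] if per_channel_scaling else s_b)
    return a @ w.T  # weights are stored [N, K], hence the transpose


def sq_gemm_scale_after(a_q, w_q, scales_a, scales_b):
    """Hypothetical analogue of the dynamic-shape workaround branch."""
    # GEMM on the unscaled values first (scale of 1.0), then multiply by the
    # outer product of the scales: [M, 1] @ [1, N] -> [M, N].
    acc = a_q.astype(np.float32) @ w_q.astype(np.float32).T
    return acc * (scales_a @ scales_b)


rng = np.random.default_rng(0)
M, K, N = 4, 8, 3
a_q = rng.integers(-127, 128, size=(M, K)).astype(np.int8)
w_q = rng.integers(-127, 128, size=(N, K)).astype(np.int8)
scales_a = rng.uniform(0.01, 0.1, size=(M, 1)).astype(np.float32)
scales_b = rng.uniform(0.01, 0.1, size=(1, N)).astype(np.float32)

out_dequant_first = sq_gemm_reference(a_q, w_q, scales_a, scales_b, True, True)
out_scale_after = sq_gemm_scale_after(a_q, w_q, scales_a, scales_b)
```

Under these assumptions the two results differ only by floating-point rounding, since scaling row `i` of the activations by `scales_a[i]` and row `j` of the weights by `scales_b[j]` multiplies output element `(i, j)` by `scales_a[i] * scales_b[j]`.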

Callers 3

_sq_gemm · Method · 0.90
forward · Method · 0.85
forward · Method · 0.85

Calls 15

default_net · Function · 0.85
constant · Function · 0.85
dequantize · Function · 0.85
matmul · Function · 0.85
cast · Function · 0.85
view · Function · 0.85
flatten · Function · 0.85
str_dtype_to_trt · Function · 0.85
default_trtnet · Function · 0.85
_add_plugin_info · Function · 0.85
_create_tensor · Function · 0.85
create_plugin · Method · 0.80

Tested by 1

_sq_gemm · Method · 0.72