MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / weight_only_quant_matmul

Function weight_only_quant_matmul

tensorrt_llm/quantization/functional.py:217–257  ·  view source on GitHub ↗
(input: Tensor,
                             weights: Tensor,
                             scales: Tensor,
                             weightTypeId: int,
                             dtype: str = 'float16',
                             transa: bool = False,
                             transb: bool = False)

Source from the content-addressed store, hash-verified

215
216
def weight_only_quant_matmul(input: Tensor,
                             weights: Tensor,
                             scales: Tensor,
                             weightTypeId: int,
                             dtype: str = 'float16',
                             transa: bool = False,
                             transb: bool = False) -> Tensor:
    """Multiply ``input`` by weight-only-quantized ``weights``.

    Two lowerings are used:

    * Fallback, taken when ``plugin_config.weight_only_quant_matmul_plugin``
      is falsy or when ``transa`` / ``transb`` is requested. The weights are
      dequantized with ``scales``, a regular ``matmul`` is emitted, and the
      result is cast to ``dtype``. Weights that are not already int8 are
      first quantized to int8 and then dequantized (Q->DQ).
    * Plugin, used otherwise. A ``WeightOnlyQuantMatmul`` (version '1')
      plugin layer is added with ``[input, weights, scales]`` as its inputs.

    Args:
        input: Left-hand matmul operand.
        weights: Right-hand matmul operand. It is either int8 or another type
            that the fallback path quantizes to int8 before dequantizing.
        scales: Quantization scales. The fallback path applies them along
            axis 1 of ``weights``, or along axis 0 when ``transb`` is set.
        weightTypeId: Forwarded to the plugin's ``weight_type_id`` field.
            The fallback path does not use it.
        dtype: Output dtype of the fallback path. The plugin path ignores it
            and derives its ``type_id`` from
            ``plugin_config.weight_only_quant_matmul_plugin`` instead.
        transa: Forwarded to ``matmul``. Setting it forces the fallback path.
        transb: Forwarded to ``matmul``. Setting it forces the fallback path
            and moves the scale axis to 0.

    Returns:
        The output tensor of the matmul (fallback) or of the plugin layer.

    Raises:
        AssertionError: If the ``WeightOnlyQuantMatmul`` plugin creator is
            not found in the plugin registry.
    """
    if not default_net(
    ).plugin_config.weight_only_quant_matmul_plugin or transa or transb:
        # The scale axis follows the weight layout: 1 normally, 0 when the
        # weights are transposed (transb).
        scale_axis = 0 if transb else 1
        if weights.dtype != trt.int8:
            # Q->DQ: quantize to int8 first. Quantize along the same axis
            # used for dequantization so the same `scales` tensor is valid
            # for both layers, including when transb is set.
            weights = quantize(weights, scales, dtype='int8', axis=scale_axis)
        weights = dequantize(weights, scales, scale_axis, input.dtype)

        res = matmul(input, weights, transa=transa, transb=transb)
        return cast(res, dtype)

    plg_creator = trt.get_plugin_registry().get_plugin_creator(
        'WeightOnlyQuantMatmul', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert plg_creator is not None

    weight_type_id = trt.PluginField("weight_type_id",
                                     np.array(weightTypeId, dtype=np.int32),
                                     trt.PluginFieldType.INT32)

    # The plugin's type_id is taken from the plugin config value, not from
    # the `dtype` argument.
    p_dtype = default_net().plugin_config.weight_only_quant_matmul_plugin
    pf_type = trt.PluginField(
        "type_id", np.array([int(str_dtype_to_trt(p_dtype))], np.int32),
        trt.PluginFieldType.INT32)

    pfc = trt.PluginFieldCollection([pf_type, weight_type_id])
    matmul_plug = plg_creator.create_plugin("woq_matmul", pfc)
    plug_inputs = [input.trt_tensor, weights.trt_tensor, scales.trt_tensor]
    layer = default_trtnet().add_plugin_v2(plug_inputs, matmul_plug)
    _add_plugin_info(layer, plg_creator, "woq_matmul", pfc)
    if not default_net().strongly_typed:
        # Symmetric int8 range on the weights input (plugin input 1). It is
        # only set when the network is not strongly typed.
        layer.get_input(1).set_dynamic_range(-127, 127)
    return _create_tensor(layer.get_output(0), layer)
258
259
260def weight_only_groupwise_quant_matmul(input: Tensor,

Callers 3

_run_matmul (Method) · 0.90
forward (Method) · 0.85
forward (Method) · 0.85

Calls 12

default_net (Function) · 0.85
dequantize (Function) · 0.85
matmul (Function) · 0.85
cast (Function) · 0.85
str_dtype_to_trt (Function) · 0.85
default_trtnet (Function) · 0.85
_add_plugin_info (Function) · 0.85
_create_tensor (Function) · 0.85
create_plugin (Method) · 0.80
quantize (Function) · 0.70
get_input (Method) · 0.45
get_output (Method) · 0.45

Tested by 1

_run_matmul (Method) · 0.72