| 215 | |
| 216 | |
def weight_only_quant_matmul(input: Tensor,
                             weights: Tensor,
                             scales: Tensor,
                             weightTypeId: int,
                             dtype: str = 'float16',
                             transa: bool = False,
                             transb: bool = False) -> Tensor:
    """Multiply ``input`` by weight-only-quantized ``weights``.

    Two code paths exist:

    * Fallback path, taken when the ``weight_only_quant_matmul_plugin``
      setting of the current network's plugin config is falsy, or when
      either ``transa`` or ``transb`` is set. The weights are quantized to
      int8 first if they are not already int8, then dequantized to
      ``input.dtype``, multiplied with ``input`` through ``matmul`` and the
      product is cast to ``dtype``.
    * Plugin path, taken otherwise. A ``WeightOnlyQuantMatmul`` plugin is
      created and fed ``input``, ``weights`` and ``scales``.

    Args:
        input: Left-hand operand of the matrix multiplication.
        weights: Right-hand operand; either already int8 or quantized here
            on the fallback path.
        scales: Scales handed to ``quantize``/``dequantize`` on the fallback
            path, or to the plugin as its third input.
        weightTypeId: Integer forwarded to the plugin as the
            ``weight_type_id`` field. Unused on the fallback path.
        dtype: Target type of the final ``cast`` on the fallback path. It is
            not consulted on the plugin path, where ``type_id`` is derived
            from the plugin config value instead.
        transa: Forwarded to ``matmul``; forces the fallback path when true.
        transb: Forwarded to ``matmul``; forces the fallback path when true
            and selects dequantization axis 0 instead of 1.

    Returns:
        The tensor holding the matmul result.
    """
    plugin_dtype = default_net().plugin_config.weight_only_quant_matmul_plugin

    # Guard clause: no plugin configured, or a transposed operand requested.
    if not plugin_dtype or transa or transb:
        dequant_axis = 0 if transb else 1
        if weights.dtype != trt.int8:
            # Q->DQ: weights are not int8 yet, so quantize before the
            # dequantization below.
            # NOTE(review): quantization always uses axis=1 while the
            # dequantization axis follows transb -- confirm this is intended.
            weights = quantize(weights, scales, dtype='int8', axis=1)
        weights = dequantize(weights, scales, dequant_axis, input.dtype)

        product = matmul(input, weights, transa=transa, transb=transb)
        return cast(product, dtype)

    creator = trt.get_plugin_registry().get_plugin_creator(
        'WeightOnlyQuantMatmul', '1', TRT_LLM_PLUGIN_NAMESPACE)
    assert creator is not None

    weight_type_field = trt.PluginField(
        "weight_type_id",
        np.array(weightTypeId, dtype=np.int32),
        trt.PluginFieldType.INT32,
    )
    type_field = trt.PluginField(
        "type_id",
        np.array([int(str_dtype_to_trt(plugin_dtype))], np.int32),
        trt.PluginFieldType.INT32,
    )
    field_collection = trt.PluginFieldCollection(
        [type_field, weight_type_field])

    plugin = creator.create_plugin("woq_matmul", field_collection)
    layer = default_trtnet().add_plugin_v2(
        [input.trt_tensor, weights.trt_tensor, scales.trt_tensor], plugin)
    _add_plugin_info(layer, creator, "woq_matmul", field_collection)

    # Without strong typing, pin an explicit dynamic range on the weights
    # input (plugin input index 1).
    if not default_net().strongly_typed:
        layer.get_input(1).set_dynamic_range(-127, 127)

    return _create_tensor(layer.get_output(0), layer)
| 258 | |
| 259 | |
| 260 | def weight_only_groupwise_quant_matmul(input: Tensor, |