MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / fp8_rowwise_quantize

Function fp8_rowwise_quantize

tensorrt_llm/quantization/quantize.py:249–321  ·  view source on GitHub ↗
(model, quant_config: QuantConfig)

Source from the content-addressed store, hash-verified

247
248
249def fp8_rowwise_quantize(model, quant_config: QuantConfig):
250 assert quant_config.quant_mode.has_fp8_rowwise()
251
252 quant_cls_map = {
253 RmsNorm: Fp8RowwiseRmsNorm,
254 LayerNorm: Fp8RowwiseLayerNorm,
255 GatedMLP: Fp8RowwiseGatedMLP,
256 MLP: Fp8RowwiseMLP,
257 Attention: Fp8RowwiseAttention,
258 }
259
260 exclude_modules = quant_config.exclude_modules
261 if exclude_modules is None:
262 exclude_modules = []
263 # Always exclude these modules for FP8 rowwise
264 exclude_modules = list(
265 set(exclude_modules + ['*ln_f', '*ln_embed', '*lm_head']))
266
267 def extract_layer_idx(name):  # first all-digit component of a dotted name, as an int
268 ss = name.split('.')  # break the dotted name into its components
269 for s in ss:
270 if s.isdigit():  # the earliest all-digit component wins
271 return int(s)
272 return None  # NOTE(review): indentation is stripped in this listing; presumably this is the after-loop fallback when no component is all digits - confirm
273
274 # Meta's LLaMA 3.1 recipe:
275 # (1) Skip quantization for the first and last Transformer layers
276 # (2) Skip quantization for the Attention layers
277 if quant_config.use_meta_recipe:
278 exclude_modules.extend(['*input_layernorm', '*attention'])
279
280 for name, layer, parent in model.named_modules_with_parent():
281 module_name = name.rsplit('.', 1)[-1]
282
283 if quant_config.use_meta_recipe:
284 local_layer_idx = extract_layer_idx(name)
285 mapping = model.config.mapping
286 layers_range = mapping.pp_layers(model.config.num_hidden_layers)
287 if mapping.is_first_pp_rank() and local_layer_idx == 0:
288 continue
289 if mapping.is_last_pp_rank(
290 ) and local_layer_idx == len(layers_range) - 1:
291 continue
292
293 quant_cls = None
294 for cls in quant_cls_map:
295 if isinstance(layer, cls):
296 quant_cls = quant_cls_map[cls]
297 break
298 if quant_cls is None:
299 continue
300
301 is_excluded = False
302 for exclude_module in exclude_modules:
303 if fnmatch.fnmatchcase(name, exclude_module):
304 is_excluded = True
305 break
306 if is_excluded:

Callers 1

quantize · Function · 0.85

Calls 7

get_init_params · Function · 0.85
pp_layers · Method · 0.80
extract_layer_idx · Function · 0.70
has_fp8_rowwise · Method · 0.45
is_first_pp_rank · Method · 0.45
is_last_pp_rank · Method · 0.45

Tested by

no test coverage detected