(model, quant_config: QuantConfig)
| 247 | |
| 248 | |
| 249 | def fp8_rowwise_quantize(model, quant_config: QuantConfig): |
| 250 | assert quant_config.quant_mode.has_fp8_rowwise() |
| 251 | |
| 252 | quant_cls_map = { |
| 253 | RmsNorm: Fp8RowwiseRmsNorm, |
| 254 | LayerNorm: Fp8RowwiseLayerNorm, |
| 255 | GatedMLP: Fp8RowwiseGatedMLP, |
| 256 | MLP: Fp8RowwiseMLP, |
| 257 | Attention: Fp8RowwiseAttention, |
| 258 | } |
| 259 | |
| 260 | exclude_modules = quant_config.exclude_modules |
| 261 | if exclude_modules is None: |
| 262 | exclude_modules = [] |
| 263 | # Always exclude these modules for FP8 rowwise |
| 264 | exclude_modules = list( |
| 265 | set(exclude_modules + ['*ln_f', '*ln_embed', '*lm_head'])) |
| 266 | |
| 267 | def extract_layer_idx(name): |
| 268 | ss = name.split('.') |
| 269 | for s in ss: |
| 270 | if s.isdigit(): |
| 271 | return int(s) |
| 272 | return None |
| 273 | |
| 274 | # Meta's LLaMA 3.1 recipe: |
| 275 | # (1) Skip quantization for the first and last Transformer layers |
| 276 | # (2) Skip quantization for the Attention layers |
| 277 | if quant_config.use_meta_recipe: |
| 278 | exclude_modules.extend(['*input_layernorm', '*attention']) |
| 279 | |
| 280 | for name, layer, parent in model.named_modules_with_parent(): |
| 281 | module_name = name.rsplit('.', 1)[-1] |
| 282 | |
| 283 | if quant_config.use_meta_recipe: |
| 284 | local_layer_idx = extract_layer_idx(name) |
| 285 | mapping = model.config.mapping |
| 286 | layers_range = mapping.pp_layers(model.config.num_hidden_layers) |
| 287 | if mapping.is_first_pp_rank() and local_layer_idx == 0: |
| 288 | continue |
| 289 | if mapping.is_last_pp_rank( |
| 290 | ) and local_layer_idx == len(layers_range) - 1: |
| 291 | continue |
| 292 | |
| 293 | quant_cls = None |
| 294 | for cls in quant_cls_map: |
| 295 | if isinstance(layer, cls): |
| 296 | quant_cls = quant_cls_map[cls] |
| 297 | break |
| 298 | if quant_cls is None: |
| 299 | continue |
| 300 | |
| 301 | is_excluded = False |
| 302 | for exclude_module in exclude_modules: |
| 303 | if fnmatch.fnmatchcase(name, exclude_module): |
| 304 | is_excluded = True |
| 305 | break |
| 306 | if is_excluded: |
no test coverage detected