| 567 | |
| 568 | |
def quantize(model, quant_config: Union[QuantConfig, LayerQuantConfig]):
    """Replace the modules of ``model`` with their quantized versions.

    Iterates over ``model.named_modules_with_parent()``. Every module whose
    quantization mode differs from ``QuantMode(0)`` is handed to the
    quantization routine selected by that mode, and the routine's result is
    stored back on the module's parent under the last component of its
    dotted name.

    When ``quant_config.quant_algo`` is ``QuantAlgo.MIXED_PRECISION`` the
    mode is obtained per module by calling
    ``quant_config.layer_quant_mode(name)``; for any other algorithm
    ``quant_config.layer_quant_mode`` is used directly.

    Args:
        model: Object providing ``named_modules_with_parent()`` and a
            ``config`` attribute (the latter is forwarded to the
            weight-only routines).
        quant_config: Quantization settings; must provide ``quant_algo``,
            ``layer_quant_mode``, ``quant_mode`` and ``_get_quant_cfg``.

    Returns:
        The resulting model with its ``quant_mode`` attribute set to
        ``quant_config.quant_mode``. If a module without a parent was
        processed, this is the object produced for that module; if KV-cache
        quantization is enabled, it is the result of ``kv_cache_quantize``.
    """
    for name, module, parent in model.named_modules_with_parent():
        # Mixed precision resolves the mode per module name; every other
        # algorithm uses the config's mode as-is.
        mode = (quant_config.layer_quant_mode(name)
                if quant_config.quant_algo == QuantAlgo.MIXED_PRECISION else
                quant_config.layer_quant_mode)
        if mode == QuantMode(0):
            continue  # no quantization requested for this module

        cfg = quant_config._get_quant_cfg(name)

        # The first matching mode wins, so the order of these checks matters.
        if mode.has_fp8_qdq():
            quantized = fp8_quantize(module, cfg)
        elif mode.has_fp8_rowwise():
            quantized = fp8_rowwise_quantize(module, cfg)
        elif mode.is_qserve_w4a8():
            # This routine is given the full quant_config, not the
            # per-module cfg.
            quantized = qserve_quantize(module, quant_config)
        elif mode.has_nvfp4():
            quantized = fp4_quantize(module, cfg)
        elif mode.has_act_and_weight_quant():
            quantized = smooth_quantize(module, cfg)
        elif mode.is_weight_only():
            weight_only_fn = (weight_only_groupwise_quantize
                              if mode.has_per_group_scaling() else
                              weight_only_quantize)
            quantized = weight_only_fn(module, cfg, model.config)
        else:
            # No routine matches this mode: keep the module unchanged.
            quantized = module

        if parent is None:
            # A parentless module stands for the whole model: adopt the
            # result as the model and stop walking.
            model = quantized
            break

        # Per-layer case: rebind the attribute on the parent module.
        setattr(parent, name.rpartition('.')[-1], quantized)

    if quant_config.quant_mode.has_kv_cache_quant():
        model = kv_cache_quantize(model)

    setattr(model, 'quant_mode', quant_config.quant_mode)
    return model