Patch quantization config.
(hf_config: Any, model_format: str = None)
| 278 | |
| 279 | |
def _patch_quantization_config(hf_config: Any, model_format: str = None):
    """Attach a weight-quantization config to ``hf_config`` when requested.

    Args:
        hf_config: Model config object; it is mutated in place.
        model_format: Requested weight quantization method. ``None`` means
            no patching; ``'fp8'`` is the only method handled here.

    Returns:
        The same ``hf_config`` object that was passed in, patched or not.

    Raises:
        RuntimeError: ``model_format`` is neither ``None`` nor ``'fp8'`` and
            the config is not already quantized.
    """
    # Nothing requested: hand the config back untouched.
    if model_format is None:
        return hf_config

    # A config that already carries a quantization config, either at the top
    # level or on its nested llm/text config, is skipped with a warning.
    already_quantized = hasattr(hf_config, 'quantization_config') or any(
        hasattr(hf_config, sub_name) and hasattr(getattr(hf_config, sub_name), 'quantization_config')
        for sub_name in ('llm_config', 'text_config'))
    if already_quantized:
        logger.warning('Can not perform weight quantization on quantized model.')
        return hf_config

    if model_format != 'fp8':
        raise RuntimeError(f'Unsupported weight quantization method: {model_format}')

    logger.debug('Patch quantization config for fp8.')
    # Imported only on the fp8 path, after the debug message is emitted.
    from lmdeploy.pytorch.envs import scale_fmt
    quant_cfg = {
        'quant_method': 'fp8',
        'fmt': 'e4m3',
        'weight_block_size': [128, 128],
        'scale_fmt': scale_fmt,
    }

    hf_config.quantization_config = quant_cfg
    # For vlm models: the same dict object is also set on the nested language
    # config. ``text_config`` wins over ``llm_config``; only one is patched.
    for sub_name in ('text_config', 'llm_config'):
        if hasattr(hf_config, sub_name):
            getattr(hf_config, sub_name).quantization_config = quant_cfg
            break

    return hf_config
| 307 | |
| 308 | |
| 309 | @dataclass |