| 273 | |
| 274 | @dataclasses.dataclass |
| 275 | class LayerQuantConfig(QuantConfig): |
| 276 | quant_algo: Optional[QuantConfig] = None |
| 277 | kv_cache_quant_algo: Optional[QuantConfig] = None |
| 278 | quantized_layers: Optional[Dict[str, QuantConfig]] = None |
| 279 | |
def __init__(self,
             *,
             quant_algo: Optional[QuantConfig] = None,
             kv_cache_quant_algo: Optional[QuantConfig] = None,
             quantized_layers: Optional[Dict[str, QuantConfig]] = None,
             **kwargs):
    """Build a per-layer quantization config.

    Args:
        quant_algo: Stored unchanged on ``self.quant_algo``.
        kv_cache_quant_algo: Stored unchanged on
            ``self.kv_cache_quant_algo`` and passed as the second argument
            to ``QuantMode.from_quant_algo`` for every layer entry.
        quantized_layers: Mapping from a layer key to that layer's config
            object; only each value's ``quant_algo`` attribute is read
            here. ``None`` (the default) is treated as "no layers".
        **kwargs: Any other keyword argument is not applied; a warning is
            logged for each one.
    """
    self.quant_algo = quant_algo
    self.quantized_layers = quantized_layers
    self.kv_cache_quant_algo = kv_cache_quant_algo

    # The parameter defaults to None, so guard before iterating: calling
    # ``.items()`` on None would make the no-argument constructor raise.
    # The attribute itself is left exactly as the caller passed it.
    layers = quantized_layers if quantized_layers is not None else {}

    # One QuantMode per layer key, derived from that layer's algorithm and
    # the shared KV-cache algorithm.
    self.auto_quant_mode = {
        name: QuantMode.from_quant_algo(
            layer_config.quant_algo,
            self.kv_cache_quant_algo,
        )
        for name, layer_config in layers.items()
    }

    for key, value in kwargs.items():
        logger.warning(
            f"Warning: Unrecognized parameter '{key}' with value '{value}'"
        )
| 302 | |
@cached_property
def quant_mode(self):
    """Return a ``QuantModeWrapper`` over the distinct per-layer quant modes.

    Duplicate values in ``self.auto_quant_mode`` are collapsed through a
    set, so the order of the list handed to the wrapper is the set's
    iteration order. The result is cached on first access.
    """
    distinct_modes = set(self.auto_quant_mode.values())
    return QuantModeWrapper(list(distinct_modes))
| 307 | |
| 308 | #@lru_cache(maxsize=None) |
def layer_quant_mode(self, layer_name) -> QuantMode:
    """Return the quant mode of the first pattern that matches ``layer_name``.

    The keys of ``self.auto_quant_mode`` are used as ``fnmatch`` patterns
    and are tried in the dict's iteration order; the value of the first
    match is returned. When no key matches, ``QuantMode(0)`` is returned.
    """
    no_match = object()
    matching_modes = (
        mode for pattern, mode in self.auto_quant_mode.items()
        if fnmatch.fnmatch(layer_name, pattern)
    )
    selected = next(matching_modes, no_match)
    if selected is not no_match:
        return selected

    # Build the fallback only on a miss, as the original loop did.
    return QuantMode(0)
| 316 | |
@cached_property
def auto_quant_list(self):
    """Return the distinct ``quant_algo`` values used by the configured layers.

    Returns:
        A list with one entry per distinct ``quant_algo`` found among the
        values of ``self.quantized_layers``. Order follows set iteration
        and is not otherwise defined. An empty list is returned when
        ``quantized_layers`` is ``None`` or empty.
    """
    # ``quantized_layers`` is declared Optional and defaults to None;
    # treat that the same as "no layers" instead of failing on ``.items()``.
    if not self.quantized_layers:
        return []
    # Only the per-layer configs are needed, so iterate the values directly.
    return list({
        layer_config.quant_algo
        for layer_config in self.quantized_layers.values()
    })
| 323 | |
| 324 | @classmethod |
| 325 | def from_dict(cls, config: dict): |
| 326 | quantized_layers = config.pop('quantized_layers', {}) |
| 327 | |
| 328 | quantized_layers_dict = { |
| 329 | layer_name: QuantConfig(**layer_config) |
| 330 | for layer_name, layer_config in quantized_layers.items() |
| 331 | } |
| 332 | |