(model)
| 554 | |
# The KV cache is assumed to be enabled for all layers.
def kv_cache_quantize(model):
    """Attach KV-cache scaling-factor parameters to every attention module.

    Walks ``model.named_modules()``. For each module that is an
    ``Attention``, ``SmoothQuantAttention`` or ``Fp8RowwiseAttention``
    instance, it sets two attributes, each a ``Parameter`` with
    ``shape=(1, )`` and ``dtype='float32'``:

    * ``kv_cache_scaling_factor`` -- the scale used for dequantization.
    * ``kv_cache_rcp_scaling_factor`` -- the scale used for quantization.
      Presumably this is the reciprocal of the dequant scale, going by the
      name; TODO confirm.

    Args:
        model: Module tree to modify. It must provide ``named_modules()``.

    Returns:
        The same ``model`` object, modified in place.
    """
    for _, submodule in model.named_modules():
        if not isinstance(
                submodule,
            (Attention, SmoothQuantAttention, Fp8RowwiseAttention)):
            continue
        # Scale applied when reading (dequantizing) the KV cache.
        submodule.kv_cache_scaling_factor = Parameter(shape=(1, ),
                                                      dtype='float32')
        # Scale applied when writing (quantizing) into the KV cache.
        submodule.kv_cache_rcp_scaling_factor = Parameter(shape=(1, ),
                                                          dtype='float32')
    return model
| 567 | |
| 568 | |
| 569 | def quantize(model, quant_config: Union[QuantConfig, LayerQuantConfig]): |
no test coverage detected