(model, quant_config: QuantConfig, model_config=None)
| 101 | |
| 102 | |
def weight_only_quantize(model, quant_config: QuantConfig, model_config=None):
    """Swap linear/embedding layers of ``model`` for weight-only quantized ones.

    ``ColumnLinear``, ``RowLinear`` and ``Embedding`` layers are mapped to
    ``WeightOnlyQuantColumnLinear``, ``WeightOnlyQuantRowLinear`` and
    ``WeightOnlyQuantEmbedding`` respectively. The substitution itself is
    delegated to ``quantize_layers``, which is given a callback that adjusts
    the init parameters of each replacement layer.

    Args:
        model: Model whose layers are replaced. Its ``config`` attribute, when
            present, supplies the tensor-parallel mapping (``mapping.tp_rank``).
        quant_config: Quantization settings; ``quant_config.quant_mode`` must
            be a weight-only mode.
        model_config: Fallback config, used only when ``model`` has no
            ``config`` attribute.

    Returns:
        The model returned by ``quantize_layers``.

    Raises:
        ValueError: If ``quant_config.quant_mode`` is not weight-only, or if a
            replaced layer carries ``tp_rank`` but no config is available to
            read it from.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not quant_config.quant_mode.is_weight_only():
        raise ValueError(
            "weight_only_quantize requires a weight-only quant_mode, "
            f"got {quant_config.quant_mode!r}")

    # Prefer the model's own config; fall back to the caller-supplied one.
    model_cfg = getattr(model, "config", model_config)

    quant_map = {
        ColumnLinear: WeightOnlyQuantColumnLinear,
        RowLinear: WeightOnlyQuantRowLinear,
        Embedding: WeightOnlyQuantEmbedding,
    }

    def preprocess_init_params(init_params, name, module):
        # Mutates ``init_params`` in place for the layer ``module`` found at
        # dotted path ``name`` (presumably the kwargs used to build its
        # replacement -- confirm against quantize_layers).
        init_params["quant_mode"] = quant_config.quant_mode
        if isinstance(module, ColumnLinear):
            # Only a ColumnLinear whose last dotted path component is
            # "lm_head" gets transb=True.
            module_name = name.rsplit('.', 1)[-1]
            init_params["transb"] = module_name == "lm_head"
        if "tp_rank" in init_params:
            # Resolved lazily so models without tp_rank layers do not need
            # a config at all.
            if model_cfg is None:
                raise ValueError(
                    f"Layer '{name}' requires tp_rank, but the model has no "
                    "'config' attribute and no model_config was given")
            init_params["tp_rank"] = model_cfg.mapping.tp_rank

    return quantize_layers(
        model,
        quant_config,
        quant_map,
        preprocess_init_params,
    )
| 132 | |
| 133 | |
| 134 | def weight_only_groupwise_quantize(model, |
no test coverage detected