(weights,
model_config,
quant_algo,
from_pruned=False)
| 1736 | |
| 1737 | |
| 1738 | def preprocess_perlayer_weights(weights, |
| 1739 | model_config, |
| 1740 | quant_algo, |
| 1741 | from_pruned=False): |
| 1742 | exclude_modules = model_config.quantization.exclude_modules |
| 1743 | |
| 1744 | # INT4_AWQ |
| 1745 | if quant_algo == QuantAlgo.W4A8_AWQ or quant_algo == QuantAlgo.W4A16_AWQ: |
| 1746 | preprocessor = preprocess_weights_for_mixed_gemm |
| 1747 | if quant_algo == QuantAlgo.W4A8_AWQ: |
| 1748 | activation_type = torch.float8_e4m3fn |
| 1749 | elif quant_algo == QuantAlgo.W4A16_AWQ: |
| 1750 | activation_type = torch.float16 |
| 1751 | for name, param in weights.items(): |
| 1752 | if from_pruned and param.numel() == 0: |
| 1753 | continue |
| 1754 | if name.endswith('weight') and param.dtype == torch.int8: |
| 1755 | dtype = torch.float16 |
| 1756 | if model_config.dtype == "bfloat16": |
| 1757 | dtype = torch.bfloat16 |
| 1758 | weights[name] = preprocessor(param.transpose(-1, -2), |
| 1759 | torch.quint4x2, |
| 1760 | activation_type).view(dtype) |
| 1761 | if name.endswith('weights_scaling_factor'): |
| 1762 | weights[name] = param.transpose(-1, -2).contiguous().to( |
| 1763 | str_dtype_to_torch(model_config.dtype)) |
| 1764 | if name.endswith('prequant_scaling_factor'): |
| 1765 | if len(weights[name].shape) == 2: |
| 1766 | # MoE experts share the same scaling factor. |
| 1767 | param = param[0, :] |
| 1768 | weights[name] = param.reshape(1, -1) |
| 1769 | if model_config.mapping.tp_rank > 0: |
| 1770 | if name.endswith('attention.dense.bias') or name.endswith( |
| 1771 | 'mlp.proj.bias'): |
| 1772 | weights[name] = torch.zeros_like(param) |
| 1773 | |
| 1774 | if quant_algo == QuantAlgo.W4A8_AWQ: |
| 1775 | for name in list(weights): |
| 1776 | if name.endswith('weights_scaling_factor'): |
| 1777 | activation_scaling_factor = weights.pop( |
| 1778 | name.replace('weights_scaling_factor', |
| 1779 | 'activation_scaling_factor')) |
| 1780 | weights_scaling_factor_2 = weights.pop( |
| 1781 | name.replace('weights_scaling_factor', |
| 1782 | 'weights_scaling_factor_2')) |
| 1783 | weights[name] /= weights_scaling_factor_2 |
| 1784 | weights[name] = weights[name].to(torch.float16).view( |
| 1785 | str_dtype_to_torch(model_config.dtype)) |
| 1786 | weights[name.replace( |
| 1787 | 'weights_scaling_factor', |
| 1788 | 'prequant_scaling_factor')] /= activation_scaling_factor |
| 1789 | weights[name.replace( |
| 1790 | 'weights_scaling_factor', 'alpha' |
| 1791 | )] = activation_scaling_factor * weights_scaling_factor_2 |
| 1792 | weights[name.replace('weights_scaling_factor', |
| 1793 | 'activation_scaling_factor' |
| 1794 | )] = activation_scaling_factor |
| 1795 |
no test coverage detected