MCPcopy
hub / github.com/NVIDIA/TensorRT-LLM / preprocess_perlayer_weights

Function preprocess_perlayer_weights

tensorrt_llm/models/modeling_utils.py:1738–1857  ·  view source on GitHub ↗
preprocess_perlayer_weights(weights, model_config, quant_algo, from_pruned=False)

Source from the content-addressed store, hash-verified

1736
1737
1738def preprocess_perlayer_weights(weights,
1739 model_config,
1740 quant_algo,
1741 from_pruned=False):
1742 exclude_modules = model_config.quantization.exclude_modules
1743
1744 # INT4_AWQ
1745 if quant_algo == QuantAlgo.W4A8_AWQ or quant_algo == QuantAlgo.W4A16_AWQ:
1746 preprocessor = preprocess_weights_for_mixed_gemm
1747 if quant_algo == QuantAlgo.W4A8_AWQ:
1748 activation_type = torch.float8_e4m3fn
1749 elif quant_algo == QuantAlgo.W4A16_AWQ:
1750 activation_type = torch.float16
1751 for name, param in weights.items():
1752 if from_pruned and param.numel() == 0:
1753 continue
1754 if name.endswith('weight') and param.dtype == torch.int8:
1755 dtype = torch.float16
1756 if model_config.dtype == "bfloat16":
1757 dtype = torch.bfloat16
1758 weights[name] = preprocessor(param.transpose(-1, -2),
1759 torch.quint4x2,
1760 activation_type).view(dtype)
1761 if name.endswith('weights_scaling_factor'):
1762 weights[name] = param.transpose(-1, -2).contiguous().to(
1763 str_dtype_to_torch(model_config.dtype))
1764 if name.endswith('prequant_scaling_factor'):
1765 if len(weights[name].shape) == 2:
1766 # MoE experts share the same scaling factor.
1767 param = param[0, :]
1768 weights[name] = param.reshape(1, -1)
1769 if model_config.mapping.tp_rank > 0:
1770 if name.endswith('attention.dense.bias') or name.endswith(
1771 'mlp.proj.bias'):
1772 weights[name] = torch.zeros_like(param)
1773
1774 if quant_algo == QuantAlgo.W4A8_AWQ:
1775 for name in list(weights):
1776 if name.endswith('weights_scaling_factor'):
1777 activation_scaling_factor = weights.pop(
1778 name.replace('weights_scaling_factor',
1779 'activation_scaling_factor'))
1780 weights_scaling_factor_2 = weights.pop(
1781 name.replace('weights_scaling_factor',
1782 'weights_scaling_factor_2'))
1783 weights[name] /= weights_scaling_factor_2
1784 weights[name] = weights[name].to(torch.float16).view(
1785 str_dtype_to_torch(model_config.dtype))
1786 weights[name.replace(
1787 'weights_scaling_factor',
1788 'prequant_scaling_factor')] /= activation_scaling_factor
1789 weights[name.replace(
1790 'weights_scaling_factor', 'alpha'
1791 )] = activation_scaling_factor * weights_scaling_factor_2
1792 weights[name.replace('weights_scaling_factor',
1793 'activation_scaling_factor'
1794 )] = activation_scaling_factor
1795

Callers 1

preprocess_weights — Function · 0.85

Calls 8

str_dtype_to_torch — Function · 0.85
transpose — Method · 0.80
pop — Method · 0.80
replace — Method · 0.80
numel — Method · 0.45
view — Method · 0.45
to — Method · 0.45

Tested by

no test coverage detected