hub / github.com/NVIDIA/TensorRT-LLM / preprocess_weights

Function preprocess_weights

tensorrt_llm/models/modeling_utils.py:1860–1950

Modifies weights and model_config in place so that they are compatible with each other. Note: it should typically be called before model creation and weight loading. For example: preprocess_weights(weights, model_config), then model = XXXForCausalLM(model_config), then model.load(weights).

(weights: Dict[str, torch.Tensor],
                       model_config: PretrainedConfig,
                       from_pruned=False)

Source from the content-addressed store, hash-verified

def preprocess_weights(weights: Dict[str, torch.Tensor],
                       model_config: PretrainedConfig,
                       from_pruned=False) -> None:
    """This function in-place modifies weights and model_config, making them compatible with each other.

    Note: Typically, it should be called before model creation and weight loading. For example,
        preprocess_weights(weights, model_config)
        model = XXXForCausalLM(model_config)
        model.load(weights)
    """
    quant_config = model_config.quantization
    quant_algo = quant_config.quant_algo

    pattern_info = ['fc', 'gate', 'proj', 'qkv', 'dense']

    def process_kv_scaling_factor(weights: Dict[str, torch.Tensor]):
        new_entries = {}
        names_to_delete = set()

        # If k, v cache scaling factors are stored separately, combine them into kv cache scaling factor.
        for name, param in weights.items():
            if name.endswith('.k_cache_scaling_factor'):
                v_name = name.replace('k_cache_scaling_factor',
                                      'v_cache_scaling_factor')
                assert v_name in weights, f"{v_name} not found"
                kv_name = name.replace('k_cache_scaling_factor',
                                       'kv_cache_scaling_factor')
                new_entries[kv_name] = torch.max(weights[name], weights[v_name])
                names_to_delete.update([name, v_name])
        weights.update(new_entries)
        for k in names_to_delete:
            del weights[k]

        new_entries = []
        # The unified converter generate_tllm_weights() already generates these rcp weights, but legacy
        # converters do not. Handle it here.
        for name, param in weights.items():
            if name.endswith('.kv_cache_scaling_factor'):
                rcp_name = name.replace('kv_cache_scaling_factor',
                                        'kv_cache_rcp_scaling_factor')
                if rcp_name not in weights:
                    new_entries.append((rcp_name, torch.reciprocal(param)))
        weights.update(new_entries)

    process_kv_scaling_factor(weights)

    per_layer_weights = {}

    for name, param in weights.items():
        in_mode = False
        for info in pattern_info:
            pattern = rf'(.*?{info}.*?)'
            pattern_match = re.match(pattern, name)
            if pattern_match:
                base_name = pattern_match.group(1)
                if base_name not in per_layer_weights.keys():
                    per_layer_weights[base_name] = {}
                per_layer_weights[base_name][name] = param
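
The nested helper process_kv_scaling_factor can be hard to follow in the listing, so here is a minimal standalone sketch of the same two passes. It is not the library's code: plain Python floats and the built-in max stand in for torch tensors, torch.max and torch.reciprocal, and the function name merge_kv_scaling_factors and the weight names in the example are hypothetical.

```python
from typing import Dict


def merge_kv_scaling_factors(weights: Dict[str, float]) -> None:
    """Mutate `weights` in place, mirroring the two passes in the listing.

    Assumption: floats stand in for torch tensors, so max() replaces
    torch.max and 1.0 / x replaces torch.reciprocal.
    """
    new_entries = {}
    names_to_delete = set()

    # Pass 1: fold separate k/v cache scales into one kv scale (the larger one).
    for name in weights:
        if name.endswith('.k_cache_scaling_factor'):
            v_name = name.replace('k_cache_scaling_factor',
                                  'v_cache_scaling_factor')
            if v_name not in weights:
                raise KeyError(f"{v_name} not found")
            kv_name = name.replace('k_cache_scaling_factor',
                                   'kv_cache_scaling_factor')
            new_entries[kv_name] = max(weights[name], weights[v_name])
            names_to_delete.update([name, v_name])

    # Apply changes only after iteration; a dict must not change size mid-loop.
    weights.update(new_entries)
    for key in names_to_delete:
        del weights[key]

    # Pass 2: add the reciprocal scale where it is missing.
    rcp_entries = {}
    for name, value in weights.items():
        if name.endswith('.kv_cache_scaling_factor'):
            rcp_name = name.replace('kv_cache_scaling_factor',
                                    'kv_cache_rcp_scaling_factor')
            if rcp_name not in weights:
                rcp_entries[rcp_name] = 1.0 / value
    weights.update(rcp_entries)


# Hypothetical weight names, chosen only for illustration.
example = {
    'layers.0.attention.k_cache_scaling_factor': 0.5,
    'layers.0.attention.v_cache_scaling_factor': 2.0,
    'layers.0.attention.qkv.weight': 1.0,
}
merge_kv_scaling_factors(example)
```

After the call, the separate k and v entries are gone, the combined kv entry holds the larger of the two scales, and a matching reciprocal entry exists. Collecting changes in side containers and applying them after each loop is what lets the helper mutate the dictionary safely.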

Callers 2

from_checkpoint · Method · 0.85
from_hugging_face · Method · 0.85

Calls 6

match · Method · 0.45
keys · Method · 0.45
_get_quant_cfg · Method · 0.45
update · Method · 0.45
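
The match and keys calls listed here belong to the loop that groups weights per layer with re.match. The sketch below shows what the pattern rf'(.*?{info}.*?)' captures: both quantifiers are lazy, so the group is the shortest prefix of the name that ends at the first occurrence of the keyword. The helper names and weight names are hypothetical, and stopping at the first keyword that matches is an assumption, since the visible listing ends before the loop body is complete.

```python
import re
from typing import Any, Dict, Optional

PATTERN_INFO = ['fc', 'gate', 'proj', 'qkv', 'dense']


def base_name_for(name: str) -> Optional[str]:
    """Return the shortest prefix of `name` ending at a known keyword."""
    for info in PATTERN_INFO:
        # Lazy `.*?` before the keyword consumes as little as possible;
        # the trailing lazy `.*?` matches the empty string.
        found = re.match(rf'(.*?{info}.*?)', name)
        if found:
            # Assumption: the first keyword that matches wins.
            return found.group(1)
    return None


def group_by_layer(weights: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Group weight entries under their shared layer prefix."""
    grouped: Dict[str, Dict[str, Any]] = {}
    for name, param in weights.items():
        base = base_name_for(name)
        if base is not None:
            grouped.setdefault(base, {})[name] = param
    return grouped


# Hypothetical weight names; integers stand in for tensors.
weights = {
    'transformer.layers.0.attention.qkv.weight': 1,
    'transformer.layers.0.attention.qkv.weights_scaling_factor': 2,
    'transformer.layers.0.mlp.fc.weight': 3,
    'transformer.vocab_embedding.weight': 4,
}
grouped = group_by_layer(weights)
```

Here the two qkv entries land under 'transformer.layers.0.attention.qkv', the fc entry under 'transformer.layers.0.mlp.fc', and the embedding weight is left ungrouped because its name contains none of the keywords. Because the keywords are matched as plain substrings, any name that happens to contain one of them anywhere will be grouped as well.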

Tested by

no test coverage detected