This function modifies `weights` and `model_config` in place, making them compatible with each other. Note: Typically, it should be called before model creation and weight loading. For example: `preprocess_weights(weights, model_config)`; `model = XXXForCausalLM(model_config)`; `model.load(weights)`.
(weights: Dict[str, torch.Tensor],
model_config: PretrainedConfig,
from_pruned=False)
| 1858 | |
| 1859 | |
| 1860 | def preprocess_weights(weights: Dict[str, torch.Tensor], |
| 1861 | model_config: PretrainedConfig, |
| 1862 | from_pruned=False) -> None: |
| 1863 | """This function in-place modifies weights and model_config, making them compatible with each other. |
| 1864 | |
| 1865 | Note: Typically, it should be called before model creation and weight loading. For example, |
| 1866 | preprocess_weights(weights, model_config) |
| 1867 | model = XXXForCausalLM(model_config) |
| 1868 | model.load(weights) |
| 1869 | """ |
| 1870 | quant_config = model_config.quantization |
| 1871 | quant_algo = quant_config.quant_algo |
| 1872 | |
| 1873 | pattern_info = ['fc', 'gate', 'proj', 'qkv', 'dense'] |
| 1874 | |
| 1875 | def process_kv_scaling_factor(weights: Dict[str, torch.Tensor]): |
| 1876 | new_entries = {} |
| 1877 | names_to_delete = set() |
| 1878 | |
| 1879 | # If k, v cache scaling factors are stored separately, combine them into kv cache scaling factor. |
| 1880 | for name, param in weights.items(): |
| 1881 | if name.endswith('.k_cache_scaling_factor'): |
| 1882 | v_name = name.replace('k_cache_scaling_factor', |
| 1883 | 'v_cache_scaling_factor') |
| 1884 | assert v_name in weights, f"{v_name} not found" |
| 1885 | kv_name = name.replace('k_cache_scaling_factor', |
| 1886 | 'kv_cache_scaling_factor') |
| 1887 | new_entries[kv_name] = torch.max(weights[name], weights[v_name]) |
| 1888 | names_to_delete.update([name, v_name]) |
| 1889 | weights.update(new_entries) |
| 1890 | for k in names_to_delete: |
| 1891 | del weights[k] |
| 1892 | |
| 1893 | new_entries = [] |
| 1894 | # The unified converter generate_tllm_weights() already generates these rcp weights, but legacy |
| 1895 | # converters do not. Handle it here. |
| 1896 | for name, param in weights.items(): |
| 1897 | if name.endswith('.kv_cache_scaling_factor'): |
| 1898 | rcp_name = name.replace('kv_cache_scaling_factor', |
| 1899 | 'kv_cache_rcp_scaling_factor') |
| 1900 | if rcp_name not in weights: |
| 1901 | new_entries.append((rcp_name, torch.reciprocal(param))) |
| 1902 | weights.update(new_entries) |
| 1903 | |
| 1904 | process_kv_scaling_factor(weights) |
| 1905 | |
| 1906 | per_layer_weights = {} |
| 1907 | |
| 1908 | for name, param in weights.items(): |
| 1909 | in_mode = False |
| 1910 | for info in pattern_info: |
| 1911 | pattern = rf'(.*?{info}.*?)' |
| 1912 | pattern_match = re.match(pattern, name) |
| 1913 | if pattern_match: |
| 1914 | base_name = pattern_match.group(1) |
| 1915 | if base_name not in per_layer_weights.keys(): |
| 1916 | per_layer_weights[base_name] = {} |
| 1917 | per_layer_weights[base_name][name] = param |
no test coverage detected