Repository: github.com/hpcaitech/ColossalAI

Function quantize_model

colossalai/quantization/bnb.py:30–106

Quantizes an already-loaded model using the bitsandbytes settings passed in `bnb_quantization_config`, and puts the quantized model on the GPU. It takes the model (`torch.nn.Module`) and the quantization config (`BnbQuantizationConfig`) and returns the quantized `torch.nn.Module`.

(
    model: torch.nn.Module,
    bnb_quantization_config: BnbQuantizationConfig,
)
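In the listed source, the function reads only a few attributes from `bnb_quantization_config`: `load_in_4bit`, `load_in_8bit`, `skip_modules`, `keep_in_fp32_modules` and `torch_dtype`. The sketch below mirrors its validation and defaulting steps in plain Python. `QuantConfigSketch` and `prepare_config` are hypothetical names, not ColossalAI API; the booleans `has_4bit`/`has_8bit` stand in for the module-level availability flags, and `default_skip` stands in for whatever `get_keys_to_not_convert(model)` would return.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class QuantConfigSketch:
    """Hypothetical stand-in for BnbQuantizationConfig, limited to the
    attributes that quantize_model reads in the listed source."""

    load_in_4bit: bool = False
    load_in_8bit: bool = False
    skip_modules: Optional[List[str]] = None
    keep_in_fp32_modules: Optional[List[str]] = None
    torch_dtype: str = "float16"  # a plain string here; the real config holds a torch dtype


def prepare_config(
    config: QuantConfigSketch,
    default_skip: List[str],
    has_4bit: bool,
    has_8bit: bool,
) -> QuantConfigSketch:
    """Validate the requested mode, then fill in the two module lists.

    has_4bit / has_8bit are hypothetical stand-ins for the availability flags;
    default_skip is a hypothetical stand-in for get_keys_to_not_convert(model).
    """
    # Same order and exception types as the listing: 8-bit first (ImportError), then 4-bit (ValueError).
    if config.load_in_8bit and not has_8bit:
        raise ImportError("installed bitsandbytes does not support 8-bit quantization")
    if config.load_in_4bit and not has_4bit:
        raise ValueError("installed bitsandbytes does not support 4-bit quantization")

    # Fill defaults by writing them back onto the config object, as the listing does.
    if config.skip_modules is None:
        config.skip_modules = list(default_skip)
    if config.keep_in_fp32_modules is None:
        config.keep_in_fp32_modules = []
    return config


cfg = prepare_config(QuantConfigSketch(load_in_8bit=True), ["lm_head"], has_4bit=True, has_8bit=True)
print(cfg.skip_modules, cfg.keep_in_fp32_modules)  # ['lm_head'] []
```

Because the listing writes the defaults back onto the config object, reusing one config for a second model keeps the skip list detected for the first model instead of recomputing it.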


def quantize_model(
    model: torch.nn.Module,
    bnb_quantization_config: BnbQuantizationConfig,
):
    """
    This function will quantize the input loaded model with the associated config passed in `bnb_quantization_config`.
    We will quantize the model and put the model on the GPU.

    Args:
        model (`torch.nn.Module`):
            Input model. The model must already be loaded.
        bnb_quantization_config (`BnbQuantizationConfig`):
            The bitsandbytes quantization parameters.

    Returns:
        `torch.nn.Module`: The quantized model.
    """

    load_in_4bit = bnb_quantization_config.load_in_4bit
    load_in_8bit = bnb_quantization_config.load_in_8bit

    if load_in_8bit and not IS_8BIT_BNB_AVAILABLE:
        raise ImportError(
            "You have a version of `bitsandbytes` that is not compatible with 8bit quantization,"
            " make sure you have the latest version of `bitsandbytes` installed."
        )
    if load_in_4bit and not IS_4BIT_BNB_AVAILABLE:
        raise ValueError(
            "You have a version of `bitsandbytes` that is not compatible with 4bit quantization,"
            " make sure you have the latest version of `bitsandbytes` installed."
        )

    # We keep some modules such as the lm_head in their original dtype for numerical stability reasons
    if bnb_quantization_config.skip_modules is None:
        bnb_quantization_config.skip_modules = get_keys_to_not_convert(model)

    modules_to_not_convert = bnb_quantization_config.skip_modules

    # Default to an empty list of modules to keep in full precision (float32)
    if bnb_quantization_config.keep_in_fp32_modules is None:
        bnb_quantization_config.keep_in_fp32_modules = []
    keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules

    # compatibility with peft
    model.is_loaded_in_4bit = load_in_4bit
    model.is_loaded_in_8bit = load_in_8bit

    # record the device the weights currently live on (no assertion is made here)
    model_device = next(model.parameters()).device

    model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert)

    # convert param to the right dtype
    dtype = bnb_quantization_config.torch_dtype
    for name, param in model.state_dict().items():
        if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules):
            # note: Tensor.to() is not in-place; it returns a new tensor, which is discarded here
            param.to(torch.float32)
            if param.dtype != torch.float32:
                # ... (listing ends here; source lines 88–106 are not shown)
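The dtype loop decides per `state_dict` key whether a tensor stays in float32: a key qualifies when any entry of `keep_in_fp32_modules` occurs in it as a substring. The sketch below isolates that decision in plain Python. `plan_param_dtypes` is a hypothetical helper, dtypes are plain strings, and it assumes every non-matching key is a floating-point tensor that takes `torch_dtype`; the listing is cut off before that branch, so how the real function treats those keys is not shown here.

```python
from typing import Dict, List


def plan_param_dtypes(
    param_names: List[str],
    keep_in_fp32_modules: List[str],
    default_dtype: str,
) -> Dict[str, str]:
    """Map each state_dict key to the dtype it should end up in.

    A key is kept in float32 when any entry of keep_in_fp32_modules occurs in it
    as a substring (the same test the listing uses). Every other key is assumed
    to be a floating-point tensor that takes default_dtype.
    """
    plan = {}
    for name in param_names:
        if any(module in name for module in keep_in_fp32_modules):
            plan[name] = "float32"
        else:
            plan[name] = default_dtype
    return plan


names = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.input_layernorm.weight",
    "model.norm.weight",
    "lm_head.weight",
]
for name, dtype in plan_param_dtypes(names, ["norm"], "float16").items():
    print(f"{name}: {dtype}")
```

Because the match is a substring test, the entry `"norm"` also catches `input_layernorm`, so entries should be specific enough not to pull in unrelated layers. When applying such a plan to real tensors, the result of `Tensor.to(...)` has to be kept (assigned back or copied into the parameter), since that call returns a new tensor rather than converting in place.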

Callers 3

enable_lora (method) · 0.90
enable_lora (method) · 0.90
enable_lora (method) · 0.90

Calls 11

get_keys_to_not_convert (function) · 0.85
replace_with_bnb_layers (function) · 0.85
parameters (method) · 0.45
state_dict (method) · 0.45
to (method) · 0.45
replace (method) · 0.45
cuda (method) · 0.45
current_device (method) · 0.45
empty_cache (method) · 0.45
is_available (method) · 0.45
info (method) · 0.45
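One of the callees, `replace_with_bnb_layers`, receives the model, the config and `modules_to_not_convert` and returns the converted model; its body is not shown on this page. The general technique such a helper typically uses, walking the module tree and swapping linear layers for quantized ones except those named in a skip list, can be sketched in plain Python. `Node` (a stand-in for `torch.nn.Module`) and `replace_linear` are hypothetical, and the name-matching rule here (exact short name or exact dotted name) is an assumption that may differ from the real helper.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Node:
    """Hypothetical stand-in for an nn.Module: a kind label plus named children."""

    kind: str
    children: Dict[str, "Node"] = field(default_factory=dict)


def replace_linear(node: Node, skip: List[str], prefix: str = "") -> Node:
    """Recursively swap "Linear" children for "QuantLinear" ones.

    A child is left untouched (and not descended into) when its short name or
    its full dotted name appears in skip.
    """
    for name, child in node.children.items():
        full_name = f"{prefix}.{name}" if prefix else name
        if name in skip or full_name in skip:
            continue
        if child.kind == "Linear":
            # Reassigning an existing key while iterating is safe: the dict size does not change.
            node.children[name] = Node("QuantLinear")
        else:
            replace_linear(child, skip, full_name)
    return node


model = Node("Model", {
    "layers": Node("ModuleList", {
        "0": Node("Block", {"q_proj": Node("Linear"), "norm": Node("LayerNorm")}),
    }),
    "lm_head": Node("Linear"),
})
replace_linear(model, skip=["lm_head"])
print(model.children["lm_head"].kind)  # Linear
print(model.children["layers"].children["0"].children["q_proj"].kind)  # QuantLinear
```

Swapping children in place on the parent keeps the rest of the tree untouched, which is why the real function can simply rebind `model` to the helper's return value.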

Tested by

no test coverage detected
