This function quantizes the already-loaded input model using the configuration passed in `bnb_quantization_config`, then puts the quantized model on the GPU. Args: model (`torch.nn.Module`): Input model, already loaded. bnb_quantization_config (`BnbQuantizationConfig`): The bitsandbytes quantization parameters.
(
model: torch.nn.Module,
bnb_quantization_config: BnbQuantizationConfig,
)
| 28 | |
| 29 | |
| 30 | def quantize_model( |
| 31 | model: torch.nn.Module, |
| 32 | bnb_quantization_config: BnbQuantizationConfig, |
| 33 | ): |
| 34 | """ |
| 35 | This function will quantize the input loaded model with the associated config passed in `bnb_quantization_config`. |
| 36 | We will quantize the model and put the model on the GPU. |
| 37 | |
| 38 | Args: |
| 39 | model (`torch.nn.Module`): |
| 40 | Input model. The model already loaded |
| 41 | bnb_quantization_config (`BnbQuantizationConfig`): |
| 42 | The bitsandbytes quantization parameters |
| 43 | |
| 44 | Returns: |
| 45 | `torch.nn.Module`: The quantized model |
| 46 | """ |
| 47 | |
| 48 | load_in_4bit = bnb_quantization_config.load_in_4bit |
| 49 | load_in_8bit = bnb_quantization_config.load_in_8bit |
| 50 | |
| 51 | if load_in_8bit and not IS_8BIT_BNB_AVAILABLE: |
| 52 | raise ImportError( |
| 53 | "You have a version of `bitsandbytes` that is not compatible with 8bit quantization," |
| 54 | " make sure you have the latest version of `bitsandbytes` installed." |
| 55 | ) |
| 56 | if load_in_4bit and not IS_4BIT_BNB_AVAILABLE: |
| 57 | raise ValueError( |
| 58 | "You have a version of `bitsandbytes` that is not compatible with 4bit quantization," |
| 59 | "make sure you have the latest version of `bitsandbytes` installed." |
| 60 | ) |
| 61 | |
| 62 | # We keep some modules such as the lm_head in their original dtype for numerical stability reasons |
| 63 | if bnb_quantization_config.skip_modules is None: |
| 64 | bnb_quantization_config.skip_modules = get_keys_to_not_convert(model) |
| 65 | |
| 66 | modules_to_not_convert = bnb_quantization_config.skip_modules |
| 67 | |
| 68 | # We add the modules we want to keep in full precision |
| 69 | if bnb_quantization_config.keep_in_fp32_modules is None: |
| 70 | bnb_quantization_config.keep_in_fp32_modules = [] |
| 71 | keep_in_fp32_modules = bnb_quantization_config.keep_in_fp32_modules |
| 72 | |
| 73 | # compatibility with peft |
| 74 | model.is_loaded_in_4bit = load_in_4bit |
| 75 | model.is_loaded_in_8bit = load_in_8bit |
| 76 | |
| 77 | # assert model_device is cuda |
| 78 | model_device = next(model.parameters()).device |
| 79 | |
| 80 | model = replace_with_bnb_layers(model, bnb_quantization_config, modules_to_not_convert=modules_to_not_convert) |
| 81 | |
| 82 | # convert param to the right dtype |
| 83 | dtype = bnb_quantization_config.torch_dtype |
| 84 | for name, param in model.state_dict().items(): |
| 85 | if any(module_to_keep_in_fp32 in name for module_to_keep_in_fp32 in keep_in_fp32_modules): |
| 86 | param.to(torch.float32) |
| 87 | if param.dtype != torch.float32: |
no test coverage detected
searching dependent graphs…