Load a model from Hugging Face.
load_model(
model_path: str,
device: str = "cuda",
num_gpus: int = 1,
max_gpu_memory: Optional[str] = None,
dtype: Optional[torch.dtype] = None,
load_8bit: bool = False,
cpu_offloading: bool = False,
gptq_config: Optional[GptqConfig] = None,
awq_config: Optional[AWQConfig] = None,
exllama_config: Optional[ExllamaConfig] = None,
xft_config: Optional[XftConfig] = None,
revision: str = "main",
debug: bool = False,
)
| 179 | |
| 180 | |
| 181 | def load_model( |
| 182 | model_path: str, |
| 183 | device: str = "cuda", |
| 184 | num_gpus: int = 1, |
| 185 | max_gpu_memory: Optional[str] = None, |
| 186 | dtype: Optional[torch.dtype] = None, |
| 187 | load_8bit: bool = False, |
| 188 | cpu_offloading: bool = False, |
| 189 | gptq_config: Optional[GptqConfig] = None, |
| 190 | awq_config: Optional[AWQConfig] = None, |
| 191 | exllama_config: Optional[ExllamaConfig] = None, |
| 192 | xft_config: Optional[XftConfig] = None, |
| 193 | revision: str = "main", |
| 194 | debug: bool = False, |
| 195 | ): |
| 196 | """Load a model from Hugging Face.""" |
| 197 | import accelerate |
| 198 | |
| 199 | # get model adapter |
| 200 | adapter = get_model_adapter(model_path) |
| 201 | |
| 202 | # Handle device mapping |
| 203 | cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration( |
| 204 | device, load_8bit, cpu_offloading |
| 205 | ) |
| 206 | if device == "cpu": |
| 207 | kwargs = {"torch_dtype": torch.float32} |
| 208 | if CPU_ISA in ["avx512_bf16", "amx"]: |
| 209 | try: |
| 210 | import intel_extension_for_pytorch as ipex |
| 211 | |
| 212 | kwargs = {"torch_dtype": torch.bfloat16} |
| 213 | except ImportError: |
| 214 | warnings.warn( |
| 215 | "Intel Extension for PyTorch is not installed, it can be installed to accelerate cpu inference" |
| 216 | ) |
| 217 | elif device == "cuda": |
| 218 | kwargs = {"torch_dtype": torch.float16} |
| 219 | if num_gpus != 1: |
| 220 | kwargs["device_map"] = "auto" |
| 221 | if max_gpu_memory is None: |
| 222 | kwargs[ |
| 223 | "device_map" |
| 224 | ] = "sequential" # This is important for not the same VRAM sizes |
| 225 | available_gpu_memory = get_gpu_memory(num_gpus) |
| 226 | kwargs["max_memory"] = { |
| 227 | i: str(int(available_gpu_memory[i] * 0.85)) + "GiB" |
| 228 | for i in range(num_gpus) |
| 229 | } |
| 230 | else: |
| 231 | kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)} |
| 232 | elif device == "mps": |
| 233 | kwargs = {"torch_dtype": torch.float16} |
| 234 | import transformers |
| 235 | |
| 236 | version = tuple(int(v) for v in transformers.__version__.split(".")) |
| 237 | if version < (4, 35, 0): |
| 238 | # NOTE: Recent transformers library seems to fix the mps issue, also |
no test coverage detected
searching dependent graphs…