MCPcopy
hub / github.com/lm-sys/FastChat / load_model

Function load_model

fastchat/model/model_adapter.py:181–375  ·  view source on GitHub ↗

Load a model from Hugging Face.

(
    model_path: str,
    device: str = "cuda",
    num_gpus: int = 1,
    max_gpu_memory: Optional[str] = None,
    dtype: Optional[torch.dtype] = None,
    load_8bit: bool = False,
    cpu_offloading: bool = False,
    gptq_config: Optional[GptqConfig] = None,
    awq_config: Optional[AWQConfig] = None,
    exllama_config: Optional[ExllamaConfig] = None,
    xft_config: Optional[XftConfig] = None,
    revision: str = "main",
    debug: bool = False,
)

Source from the content-addressed store, hash-verified

179
180
181def load_model(
182 model_path: str,
183 device: str = "cuda",
184 num_gpus: int = 1,
185 max_gpu_memory: Optional[str] = None,
186 dtype: Optional[torch.dtype] = None,
187 load_8bit: bool = False,
188 cpu_offloading: bool = False,
189 gptq_config: Optional[GptqConfig] = None,
190 awq_config: Optional[AWQConfig] = None,
191 exllama_config: Optional[ExllamaConfig] = None,
192 xft_config: Optional[XftConfig] = None,
193 revision: str = "main",
194 debug: bool = False,
195):
196 """Load a model from Hugging Face."""
197 import accelerate
198
199 # get model adapter
200 adapter = get_model_adapter(model_path)
201
202 # Handle device mapping
203 cpu_offloading = raise_warning_for_incompatible_cpu_offloading_configuration(
204 device, load_8bit, cpu_offloading
205 )
206 if device == "cpu":
207 kwargs = {"torch_dtype": torch.float32}
208 if CPU_ISA in ["avx512_bf16", "amx"]:
209 try:
210 import intel_extension_for_pytorch as ipex
211
212 kwargs = {"torch_dtype": torch.bfloat16}
213 except ImportError:
214 warnings.warn(
215 "Intel Extension for PyTorch is not installed, it can be installed to accelerate cpu inference"
216 )
217 elif device == "cuda":
218 kwargs = {"torch_dtype": torch.float16}
219 if num_gpus != 1:
220 kwargs["device_map"] = "auto"
221 if max_gpu_memory is None:
222 kwargs[
223 "device_map"
224 ] = "sequential" # This is important for not the same VRAM sizes
225 available_gpu_memory = get_gpu_memory(num_gpus)
226 kwargs["max_memory"] = {
227 i: str(int(available_gpu_memory[i] * 0.85)) + "GiB"
228 for i in range(num_gpus)
229 }
230 else:
231 kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
232 elif device == "mps":
233 kwargs = {"torch_dtype": torch.float16}
234 import transformers
235
236 version = tuple(int(v) for v in transformers.__version__.split("."))
237 if version < (4, 35, 0):
238 # NOTE: Recent transformers library seems to fix the mps issue, also

Callers 4

__init__ · Method · 0.90
main · Function · 0.90
chat_loop · Function · 0.90
get_model_answers · Function · 0.90

Calls 11

get_gpu_memory · Function · 0.90
load_awq_quantized · Function · 0.90
load_gptq_quantized · Function · 0.90
load_exllama_model · Function · 0.90
load_xft_model · Function · 0.90
get_model_adapter · Function · 0.85
load_compress_model · Method · 0.80
to · Method · 0.80
load_model · Method · 0.45

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…