MCPcopy
hub / github.com/lm-sys/FastChat / load_compress_model

Function load_compress_model

fastchat/model/compression.py:109–223  ·  view source on GitHub ↗
(model_path, device, torch_dtype, use_fast, revision="main")

Source from the content-addressed store, hash-verified

107
108
109def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"):
110 # partially load model
111 # `use_fast=True`` is not supported for some models.
112 try:
113 tokenizer = AutoTokenizer.from_pretrained(
114 model_path, use_fast=use_fast, revision=revision, trust_remote_code=True
115 )
116 except TypeError:
117 tokenizer = AutoTokenizer.from_pretrained(
118 model_path, use_fast=~use_fast, revision=revision, trust_remote_code=True
119 )
120 with init_empty_weights():
121 # `trust_remote_code` should be set as `True` for both AutoConfig and AutoModel
122 config = AutoConfig.from_pretrained(
123 model_path,
124 low_cpu_mem_usage=True,
125 torch_dtype=torch_dtype,
126 trust_remote_code=True,
127 revision=revision,
128 )
129 # some models are loaded by AutoModel but not AutoModelForCausalLM,
130 # such as chatglm, chatglm2
131 try:
132 # google/flan-* models are based on an AutoModelForSeq2SeqLM.
133 if "T5Config" in str(type(config)):
134 model = AutoModelForSeq2SeqLM.from_config(
135 config, trust_remote_code=True
136 )
137 else:
138 model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
139 except NameError:
140 model = AutoModel.from_config(config, trust_remote_code=True)
141 linear_weights = get_compressed_list(model)
142 if os.path.exists(model_path):
143 # `model_path` is a local folder
144 base_pattern = os.path.join(model_path, "pytorch_model*.bin")
145 else:
146 # `model_path` is a cached Hugging Face repo
147 # We don't necessarily need to download the model' repo again if there is a cache.
148 # So check the default huggingface cache first.
149 model_path_temp = os.path.join(
150 os.path.expanduser("~"),
151 ".cache/huggingface/hub",
152 "models--" + model_path.replace("/", "--"),
153 "snapshots/",
154 )
155 downloaded = False
156 if os.path.exists(model_path_temp):
157 temp_last_dir = os.listdir(model_path_temp)[-1]
158 model_path_temp = os.path.join(model_path_temp, temp_last_dir)
159 base_pattern = os.path.join(model_path_temp, "pytorch_model*.bin")
160 files = glob.glob(base_pattern)
161 if len(files) > 0:
162 downloaded = True
163
164 if downloaded:
165 model_path = model_path_temp
166 else:

Callers (1)

load_compress_model (Method) · 0.90

Calls (4)

get_compressed_list (Function) · 0.85
compress (Function) · 0.85
apply_compressed_weight (Function) · 0.85
to (Method) · 0.80

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…