(model_path, device, torch_dtype, use_fast, revision="main")
| 107 | |
| 108 | |
| 109 | def load_compress_model(model_path, device, torch_dtype, use_fast, revision="main"): |
| 110 | # partially load model |
| 111 | # `use_fast=True` is not supported for some models. |
| 112 | try: |
| 113 | tokenizer = AutoTokenizer.from_pretrained( |
| 114 | model_path, use_fast=use_fast, revision=revision, trust_remote_code=True |
| 115 | ) |
| 116 | except TypeError: |
| 117 | tokenizer = AutoTokenizer.from_pretrained( |
| 118 | model_path, use_fast=~use_fast, revision=revision, trust_remote_code=True |
| 119 | ) |
| 120 | with init_empty_weights(): |
| 121 | # `trust_remote_code` should be set as `True` for both AutoConfig and AutoModel |
| 122 | config = AutoConfig.from_pretrained( |
| 123 | model_path, |
| 124 | low_cpu_mem_usage=True, |
| 125 | torch_dtype=torch_dtype, |
| 126 | trust_remote_code=True, |
| 127 | revision=revision, |
| 128 | ) |
| 129 | # some models are loaded by AutoModel but not AutoModelForCausalLM, |
| 130 | # such as chatglm, chatglm2 |
| 131 | try: |
| 132 | # google/flan-* models are based on an AutoModelForSeq2SeqLM. |
| 133 | if "T5Config" in str(type(config)): |
| 134 | model = AutoModelForSeq2SeqLM.from_config( |
| 135 | config, trust_remote_code=True |
| 136 | ) |
| 137 | else: |
| 138 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) |
| 139 | except NameError: |
| 140 | model = AutoModel.from_config(config, trust_remote_code=True) |
| 141 | linear_weights = get_compressed_list(model) |
| 142 | if os.path.exists(model_path): |
| 143 | # `model_path` is a local folder |
| 144 | base_pattern = os.path.join(model_path, "pytorch_model*.bin") |
| 145 | else: |
| 146 | # `model_path` is a cached Hugging Face repo |
| 147 | # We don't necessarily need to download the model's repo again if there is a cache. |
| 148 | # So check the default huggingface cache first. |
| 149 | model_path_temp = os.path.join( |
| 150 | os.path.expanduser("~"), |
| 151 | ".cache/huggingface/hub", |
| 152 | "models--" + model_path.replace("/", "--"), |
| 153 | "snapshots/", |
| 154 | ) |
| 155 | downloaded = False |
| 156 | if os.path.exists(model_path_temp): |
| 157 | temp_last_dir = os.listdir(model_path_temp)[-1] |
| 158 | model_path_temp = os.path.join(model_path_temp, temp_last_dir) |
| 159 | base_pattern = os.path.join(model_path_temp, "pytorch_model*.bin") |
| 160 | files = glob.glob(base_pattern) |
| 161 | if len(files) > 0: |
| 162 | downloaded = True |
| 163 | |
| 164 | if downloaded: |
| 165 | model_path = model_path_temp |
| 166 | else: |
no test coverage detected
searching dependent graphs…