(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs)
| 24 | |
| 25 | |
| 26 | def load_pretrained_model(model_path, model_base, model_name, load_8bit=False, load_4bit=False, device_map="auto", device="cuda", use_flash_attn=False, **kwargs): |
| 27 | kwargs = {"device_map": device_map, **kwargs} |
| 28 | |
| 29 | if device != "cuda": |
| 30 | kwargs['device_map'] = {"": device} |
| 31 | |
| 32 | if load_8bit: |
| 33 | kwargs['load_in_8bit'] = True |
| 34 | elif load_4bit: |
| 35 | kwargs['load_in_4bit'] = True |
| 36 | kwargs['quantization_config'] = BitsAndBytesConfig( |
| 37 | load_in_4bit=True, |
| 38 | bnb_4bit_compute_dtype=torch.float16, |
| 39 | bnb_4bit_use_double_quant=True, |
| 40 | bnb_4bit_quant_type='nf4' |
| 41 | ) |
| 42 | else: |
| 43 | kwargs['torch_dtype'] = torch.float16 |
| 44 | |
| 45 | if use_flash_attn: |
| 46 | kwargs['attn_implementation'] = 'flash_attention_2' |
| 47 | |
| 48 | if 'llava' in model_name.lower(): |
| 49 | # Load LLaVA model |
| 50 | if 'lora' in model_name.lower() and model_base is None: |
| 51 | warnings.warn('There is `lora` in model name but no `model_base` is provided. If you are loading a LoRA model, please provide the `model_base` argument. Detailed instruction: https://github.com/haotian-liu/LLaVA#launch-a-model-worker-lora-weights-unmerged.') |
| 52 | if 'lora' in model_name.lower() and model_base is not None: |
| 53 | from llava.model.language_model.llava_llama import LlavaConfig |
| 54 | lora_cfg_pretrained = LlavaConfig.from_pretrained(model_path) |
| 55 | tokenizer = AutoTokenizer.from_pretrained(model_base, use_fast=False) |
| 56 | print('Loading LLaVA from base model...') |
| 57 | model = LlavaLlamaForCausalLM.from_pretrained(model_base, low_cpu_mem_usage=True, config=lora_cfg_pretrained, **kwargs) |
| 58 | token_num, tokem_dim = model.lm_head.out_features, model.lm_head.in_features |
| 59 | if model.lm_head.weight.shape[0] != token_num: |
| 60 | model.lm_head.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) |
| 61 | model.model.embed_tokens.weight = torch.nn.Parameter(torch.empty(token_num, tokem_dim, device=model.device, dtype=model.dtype)) |
| 62 | |
| 63 | print('Loading additional LLaVA weights...') |
| 64 | if os.path.exists(os.path.join(model_path, 'non_lora_trainables.bin')): |
| 65 | non_lora_trainables = torch.load(os.path.join(model_path, 'non_lora_trainables.bin'), map_location='cpu') |
| 66 | else: |
| 67 | # this is probably from HF Hub |
| 68 | from huggingface_hub import hf_hub_download |
| 69 | def load_from_hf(repo_id, filename, subfolder=None): |
| 70 | cache_file = hf_hub_download( |
| 71 | repo_id=repo_id, |
| 72 | filename=filename, |
| 73 | subfolder=subfolder) |
| 74 | return torch.load(cache_file, map_location='cpu') |
| 75 | non_lora_trainables = load_from_hf(model_path, 'non_lora_trainables.bin') |
| 76 | non_lora_trainables = {(k[11:] if k.startswith('base_model.') else k): v for k, v in non_lora_trainables.items()} |
| 77 | if any(k.startswith('model.model.') for k in non_lora_trainables): |
| 78 | non_lora_trainables = {(k[6:] if k.startswith('model.') else k): v for k, v in non_lora_trainables.items()} |
| 79 | model.load_state_dict(non_lora_trainables, strict=False) |
| 80 | |
| 81 | from peft import PeftModel |
| 82 | print('Loading LoRA weights...') |
| 83 | model = PeftModel.from_pretrained(model, model_path) |
no test coverage detected