(attn_implementation=None)
| 786 | |
| 787 | |
| 788 | def train(attn_implementation=None): |
| 789 | global local_rank |
| 790 | |
| 791 | parser = transformers.HfArgumentParser( |
| 792 | (ModelArguments, DataArguments, TrainingArguments)) |
| 793 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
| 794 | local_rank = training_args.local_rank |
| 795 | compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) |
| 796 | |
| 797 | bnb_model_from_pretrained_args = {} |
| 798 | if training_args.bits in [4, 8]: |
| 799 | from transformers import BitsAndBytesConfig |
| 800 | bnb_model_from_pretrained_args.update(dict( |
| 801 | device_map={"": training_args.device}, |
| 802 | load_in_4bit=training_args.bits == 4, |
| 803 | load_in_8bit=training_args.bits == 8, |
| 804 | quantization_config=BitsAndBytesConfig( |
| 805 | load_in_4bit=training_args.bits == 4, |
| 806 | load_in_8bit=training_args.bits == 8, |
| 807 | llm_int8_skip_modules=["mm_projector"], |
| 808 | llm_int8_threshold=6.0, |
| 809 | llm_int8_has_fp16_weight=False, |
| 810 | bnb_4bit_compute_dtype=compute_dtype, |
| 811 | bnb_4bit_use_double_quant=training_args.double_quant, |
| 812 | bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} |
| 813 | ) |
| 814 | )) |
| 815 | |
| 816 | if model_args.vision_tower is not None: |
| 817 | if 'mpt' in model_args.model_name_or_path: |
| 818 | config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) |
| 819 | config.attn_config['attn_impl'] = training_args.mpt_attn_impl |
| 820 | model = LlavaMptForCausalLM.from_pretrained( |
| 821 | model_args.model_name_or_path, |
| 822 | config=config, |
| 823 | cache_dir=training_args.cache_dir, |
| 824 | **bnb_model_from_pretrained_args |
| 825 | ) |
| 826 | else: |
| 827 | model = LlavaLlamaForCausalLM.from_pretrained( |
| 828 | model_args.model_name_or_path, |
| 829 | cache_dir=training_args.cache_dir, |
| 830 | attn_implementation=attn_implementation, |
| 831 | torch_dtype=(torch.bfloat16 if training_args.bf16 else None), |
| 832 | **bnb_model_from_pretrained_args |
| 833 | ) |
| 834 | else: |
| 835 | model = transformers.LlamaForCausalLM.from_pretrained( |
| 836 | model_args.model_name_or_path, |
| 837 | cache_dir=training_args.cache_dir, |
| 838 | attn_implementation=attn_implementation, |
| 839 | torch_dtype=(torch.bfloat16 if training_args.bf16 else None), |
| 840 | **bnb_model_from_pretrained_args |
| 841 | ) |
| 842 | model.config.use_cache = False |
| 843 | |
| 844 | if model_args.freeze_backbone: |
| 845 | model.model.requires_grad_(False) |
no test coverage detected