()
| 600 | |
| 601 | |
| 602 | def train(): |
| 603 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) |
| 604 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() |
| 605 | compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) |
| 606 | |
| 607 | bnb_model_from_pretrained_args = {} |
| 608 | if training_args.bits in [4, 8]: |
| 609 | from transformers import BitsAndBytesConfig |
| 610 | from peft import prepare_model_for_int8_training |
| 611 | bnb_model_from_pretrained_args.update(dict( |
| 612 | device_map={"": training_args.device}, |
| 613 | load_in_4bit=training_args.bits == 4, |
| 614 | load_in_8bit=training_args.bits == 8, |
| 615 | quantization_config=BitsAndBytesConfig( |
| 616 | load_in_4bit=training_args.bits == 4, |
| 617 | load_in_8bit=training_args.bits == 8, |
| 618 | llm_int8_threshold=6.0, |
| 619 | llm_int8_has_fp16_weight=False, |
| 620 | bnb_4bit_compute_dtype=compute_dtype, |
| 621 | bnb_4bit_use_double_quant=training_args.double_quant, |
| 622 | bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'} |
| 623 | ) |
| 624 | )) |
| 625 | |
| 626 | if model_args.vision_tower is not None: |
| 627 | if 'mpt' in model_args.model_name_or_path: |
| 628 | model = LlavaMPTForCausalLM.from_pretrained( |
| 629 | model_args.model_name_or_path, |
| 630 | cache_dir=training_args.cache_dir, |
| 631 | **bnb_model_from_pretrained_args |
| 632 | ) |
| 633 | else: |
| 634 | model = LlavaLlamaForCausalLM.from_pretrained( |
| 635 | model_args.model_name_or_path, |
| 636 | cache_dir=training_args.cache_dir, |
| 637 | **bnb_model_from_pretrained_args |
| 638 | ) |
| 639 | else: |
| 640 | model = transformers.LlamaForCausalLM.from_pretrained( |
| 641 | model_args.model_name_or_path, |
| 642 | cache_dir=training_args.cache_dir, |
| 643 | **bnb_model_from_pretrained_args |
| 644 | ) |
| 645 | model.config.use_cache = False |
| 646 | |
| 647 | if model_args.freeze_backbone: |
| 648 | model.model.requires_grad_(False) |
| 649 | |
| 650 | if training_args.bits in [4, 8]: |
| 651 | model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32)) |
| 652 | model = prepare_model_for_int8_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) |
| 653 | |
| 654 | if training_args.gradient_checkpointing and model_args.vision_tower is None: |
| 655 | if hasattr(model, "enable_input_require_grads"): |
| 656 | model.enable_input_require_grads() |
| 657 | else: |
| 658 | def make_inputs_require_grad(module, input, output): |
| 659 | output.requires_grad_(True) |
no test coverage detected