MCP · copy · Index your code
hub / github.com/XPixelGroup/DiffBIR / train

Function train

llava/train/train.py:788–987  ·  view source on GitHub ↗
(attn_implementation=None)

Source from the content-addressed store, hash-verified

786
787
788def train(attn_implementation=None):
789 global local_rank
790
791 parser = transformers.HfArgumentParser(
792 (ModelArguments, DataArguments, TrainingArguments))
793 model_args, data_args, training_args = parser.parse_args_into_dataclasses()
794 local_rank = training_args.local_rank
795 compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
796
797 bnb_model_from_pretrained_args = {}
798 if training_args.bits in [4, 8]:
799 from transformers import BitsAndBytesConfig
800 bnb_model_from_pretrained_args.update(dict(
801 device_map={"": training_args.device},
802 load_in_4bit=training_args.bits == 4,
803 load_in_8bit=training_args.bits == 8,
804 quantization_config=BitsAndBytesConfig(
805 load_in_4bit=training_args.bits == 4,
806 load_in_8bit=training_args.bits == 8,
807 llm_int8_skip_modules=["mm_projector"],
808 llm_int8_threshold=6.0,
809 llm_int8_has_fp16_weight=False,
810 bnb_4bit_compute_dtype=compute_dtype,
811 bnb_4bit_use_double_quant=training_args.double_quant,
812 bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
813 )
814 ))
815
816 if model_args.vision_tower is not None:
817 if 'mpt' in model_args.model_name_or_path:
818 config = transformers.AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
819 config.attn_config['attn_impl'] = training_args.mpt_attn_impl
820 model = LlavaMptForCausalLM.from_pretrained(
821 model_args.model_name_or_path,
822 config=config,
823 cache_dir=training_args.cache_dir,
824 **bnb_model_from_pretrained_args
825 )
826 else:
827 model = LlavaLlamaForCausalLM.from_pretrained(
828 model_args.model_name_or_path,
829 cache_dir=training_args.cache_dir,
830 attn_implementation=attn_implementation,
831 torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
832 **bnb_model_from_pretrained_args
833 )
834 else:
835 model = transformers.LlamaForCausalLM.from_pretrained(
836 model_args.model_name_or_path,
837 cache_dir=training_args.cache_dir,
838 attn_implementation=attn_implementation,
839 torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
840 **bnb_model_from_pretrained_args
841 )
842 model.config.use_cache = False
843
844 if model_args.freeze_backbone:
845 model.model.requires_grad_(False)

Callers (3)

train_mem.py · File · 0.90
train_xformers.py · File · 0.90
train.py · File · 0.85

Tested by

no test coverage detected