MCPcopy
hub / github.com/apple/ml-mgie / train

Function train

mgie_train.py:602–830  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

600
601
602def train():
603 parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
604 model_args, data_args, training_args = parser.parse_args_into_dataclasses()
605 compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
606
607 bnb_model_from_pretrained_args = {}
608 if training_args.bits in [4, 8]:
609 from transformers import BitsAndBytesConfig
610 from peft import prepare_model_for_int8_training
611 bnb_model_from_pretrained_args.update(dict(
612 device_map={"": training_args.device},
613 load_in_4bit=training_args.bits == 4,
614 load_in_8bit=training_args.bits == 8,
615 quantization_config=BitsAndBytesConfig(
616 load_in_4bit=training_args.bits == 4,
617 load_in_8bit=training_args.bits == 8,
618 llm_int8_threshold=6.0,
619 llm_int8_has_fp16_weight=False,
620 bnb_4bit_compute_dtype=compute_dtype,
621 bnb_4bit_use_double_quant=training_args.double_quant,
622 bnb_4bit_quant_type=training_args.quant_type # {'fp4', 'nf4'}
623 )
624 ))
625
626 if model_args.vision_tower is not None:
627 if 'mpt' in model_args.model_name_or_path:
628 model = LlavaMPTForCausalLM.from_pretrained(
629 model_args.model_name_or_path,
630 cache_dir=training_args.cache_dir,
631 **bnb_model_from_pretrained_args
632 )
633 else:
634 model = LlavaLlamaForCausalLM.from_pretrained(
635 model_args.model_name_or_path,
636 cache_dir=training_args.cache_dir,
637 **bnb_model_from_pretrained_args
638 )
639 else:
640 model = transformers.LlamaForCausalLM.from_pretrained(
641 model_args.model_name_or_path,
642 cache_dir=training_args.cache_dir,
643 **bnb_model_from_pretrained_args
644 )
645 model.config.use_cache = False
646
647 if model_args.freeze_backbone:
648 model.model.requires_grad_(False)
649
650 if training_args.bits in [4, 8]:
651 model.config.torch_dtype=(torch.float32 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
652 model = prepare_model_for_int8_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
653
654 if training_args.gradient_checkpointing and model_args.vision_tower is None:
655 if hasattr(model, "enable_input_require_grads"):
656 model.enable_input_require_grads()
657 else:
658 def make_inputs_require_grad(module, input, output):
659 output.requires_grad_(True)

Callers 1

mgie_train.py · File · 0.85

Tested by

no test coverage detected