()
| 102 | |
| 103 | |
| 104 | def train(): |
| 105 | parser = transformers.HfArgumentParser( |
| 106 | (ModelArguments, DataArguments, TrainingArguments, LoraArguments) |
| 107 | ) |
| 108 | ( |
| 109 | model_args, |
| 110 | data_args, |
| 111 | training_args, |
| 112 | lora_args, |
| 113 | ) = parser.parse_args_into_dataclasses() |
| 114 | |
| 115 | if training_args.flash_attn: |
| 116 | replace_llama_attn_with_flash_attn() |
| 117 | |
| 118 | device_map = None |
| 119 | world_size = int(os.environ.get("WORLD_SIZE", 1)) |
| 120 | ddp = world_size != 1 |
| 121 | if lora_args.q_lora: |
| 122 | device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None |
| 123 | if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(): |
| 124 | logging.warning( |
| 125 | "FSDP and ZeRO3 are both currently incompatible with QLoRA." |
| 126 | ) |
| 127 | |
| 128 | compute_dtype = ( |
| 129 | torch.float16 |
| 130 | if training_args.fp16 |
| 131 | else (torch.bfloat16 if training_args.bf16 else torch.float32) |
| 132 | ) |
| 133 | |
| 134 | model = transformers.AutoModelForCausalLM.from_pretrained( |
| 135 | model_args.model_name_or_path, |
| 136 | cache_dir=training_args.cache_dir, |
| 137 | device_map=device_map, |
| 138 | quantization_config=BitsAndBytesConfig( |
| 139 | load_in_4bit=True, |
| 140 | bnb_4bit_use_double_quant=True, |
| 141 | bnb_4bit_quant_type="nf4", |
| 142 | bnb_4bit_compute_dtype=compute_dtype, |
| 143 | ) |
| 144 | if lora_args.q_lora |
| 145 | else None, |
| 146 | ) |
| 147 | lora_config = LoraConfig( |
| 148 | r=lora_args.lora_r, |
| 149 | lora_alpha=lora_args.lora_alpha, |
| 150 | target_modules=lora_args.lora_target_modules, |
| 151 | lora_dropout=lora_args.lora_dropout, |
| 152 | bias=lora_args.lora_bias, |
| 153 | task_type="CAUSAL_LM", |
| 154 | ) |
| 155 | |
| 156 | if lora_args.q_lora: |
| 157 | model = prepare_model_for_kbit_training( |
| 158 | model, use_gradient_checkpointing=training_args.gradient_checkpointing |
| 159 | ) |
| 160 | if not ddp and torch.cuda.device_count() > 1: |
| 161 | # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available |
no test coverage detected
searching dependent graphs…