MCPcopy
hub / github.com/lm-sys/FastChat / train

Function train

fastchat/train/train_lora.py:104–218  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

102
103
104def train():
105 parser = transformers.HfArgumentParser(
106 (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
107 )
108 (
109 model_args,
110 data_args,
111 training_args,
112 lora_args,
113 ) = parser.parse_args_into_dataclasses()
114
115 if training_args.flash_attn:
116 replace_llama_attn_with_flash_attn()
117
118 device_map = None
119 world_size = int(os.environ.get("WORLD_SIZE", 1))
120 ddp = world_size != 1
121 if lora_args.q_lora:
122 device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
123 if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
124 logging.warning(
125 "FSDP and ZeRO3 are both currently incompatible with QLoRA."
126 )
127
128 compute_dtype = (
129 torch.float16
130 if training_args.fp16
131 else (torch.bfloat16 if training_args.bf16 else torch.float32)
132 )
133
134 model = transformers.AutoModelForCausalLM.from_pretrained(
135 model_args.model_name_or_path,
136 cache_dir=training_args.cache_dir,
137 device_map=device_map,
138 quantization_config=BitsAndBytesConfig(
139 load_in_4bit=True,
140 bnb_4bit_use_double_quant=True,
141 bnb_4bit_quant_type="nf4",
142 bnb_4bit_compute_dtype=compute_dtype,
143 )
144 if lora_args.q_lora
145 else None,
146 )
147 lora_config = LoraConfig(
148 r=lora_args.lora_r,
149 lora_alpha=lora_args.lora_alpha,
150 target_modules=lora_args.lora_target_modules,
151 lora_dropout=lora_args.lora_dropout,
152 bias=lora_args.lora_bias,
153 task_type="CAUSAL_LM",
154 )
155
156 if lora_args.q_lora:
157 model = prepare_model_for_kbit_training(
158 model, use_gradient_checkpointing=training_args.gradient_checkpointing
159 )
160 if not ddp and torch.cuda.device_count() > 1:
161 # keeps Trainer from trying its own DataParallelism when more than 1 gpu is available

Callers 1

train_lora.py · File · 0.70

Calls 4

to · Method · 0.80

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…