MCP · copy
hub / github.com/deepseek-ai/DeepSeek-MoE / build_model

Function build_model

finetune/finetune.py:182–248  ·  view source on GitHub ↗
(model_args, training_args, checkpoint_dir)

Source from the content-addressed store, hash-verified

180 return data_dict
181
182def build_model(model_args, training_args, checkpoint_dir):
183 if not model_args.use_lora: assert model_args.bits in [16, 32]
184 compute_dtype = (torch.bfloat16 if training_args.bf16 else torch.float16)
185 model = transformers.AutoModelForCausalLM.from_pretrained(
186 model_args.model_name_or_path,
187 load_in_4bit=model_args.bits == 4,
188 load_in_8bit=model_args.bits == 8,
189 quantization_config=BitsAndBytesConfig(
190 load_in_4bit=model_args.bits == 4,
191 load_in_8bit=model_args.bits == 8,
192 llm_int8_threshold=6.0,
193 llm_int8_has_fp16_weight=False,
194 bnb_4bit_compute_dtype=compute_dtype,
195 bnb_4bit_use_double_quant=model_args.double_quant,
196 bnb_4bit_quant_type=model_args.quant_type,
197 ) if model_args.use_lora else None,
198 torch_dtype=compute_dtype,
199 trust_remote_code=True,
200 )
201
202 if compute_dtype == torch.float16 and model_args.bits == 4:
203 if torch.cuda.is_bf16_supported():
204 logger.info('='*80)
205 logger.info('Your GPU supports bfloat16, you can accelerate training with the argument --bf16')
206 logger.info('='*80)
207 setattr(model, 'model_parallel', True)
208 setattr(model, 'is_parallelizable', True)
209 model.config.torch_dtype=torch.bfloat16 if training_args.bf16 else torch.float32
210 # Tokenizer
211
212 if model_args.use_lora and model_args.bits < 16:
213 model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
214
215 if model_args.use_lora:
216 if checkpoint_dir is not None:
217 logger.info(f"Loading adapters from {checkpoint_dir}.")
218 # os.path.join(checkpoint_dir, 'adapter_model')
219 model = PeftModel.from_pretrained(model, checkpoint_dir, is_trainable=True)
220 else:
221 logger.info(f'Init LoRA modules...')
222 target_modules = model_args.trainable.split(',')
223 modules_to_save = model_args.modules_to_save
224 if modules_to_save is not None:
225 modules_to_save = modules_to_save.split(',')
226 lora_rank = model_args.lora_rank
227 lora_dropout = model_args.lora_dropout
228 lora_alpha = model_args.lora_alpha
229 peft_config = LoraConfig(
230 task_type=TaskType.CAUSAL_LM,
231 target_modules=target_modules,
232 inference_mode=False,
233 r=lora_rank, lora_alpha=lora_alpha,
234 lora_dropout=lora_dropout,
235 modules_to_save=modules_to_save)
236 model = get_peft_model(model, peft_config)
237
238 for name, module in model.named_modules():
239 if isinstance(module, LoraLayer):

Callers 1

train · Function · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected