加载模型
(args, training_args)
| 249 | |
| 250 | |
| 251 | def load_model(args, training_args): |
| 252 | """ |
| 253 | 加载模型 |
| 254 | """ |
| 255 | assert training_args.bf16 or training_args.fp16, 'bf16 or fp16 should be True' |
| 256 | logger.info(f'Loading model from base model: {args.model_name_or_path}') |
| 257 | logger.info(f'Train model with {args.train_mode}') |
| 258 | |
| 259 | # init model kwargs |
| 260 | # todo add flash attention |
| 261 | # attn_implementation = None |
| 262 | torch_dtype = torch.float16 if training_args.fp16 else torch.bfloat16 |
| 263 | if args.train_mode == 'qlora': |
| 264 | quantization_config = BitsAndBytesConfig( |
| 265 | load_in_4bit=True, |
| 266 | bnb_4bit_compute_dtype=torch.float16 if training_args.fp16 else torch.bfloat16, |
| 267 | bnb_4bit_use_double_quant=True, |
| 268 | bnb_4bit_quant_type="nf4", |
| 269 | llm_int8_threshold=6.0, |
| 270 | llm_int8_has_fp16_weight=False, |
| 271 | ) |
| 272 | else: |
| 273 | quantization_config = None |
| 274 | model_kwargs = dict( |
| 275 | trust_remote_code=True, |
| 276 | # attn_implementation=attn_implementation, |
| 277 | torch_dtype=torch_dtype, |
| 278 | use_cache=False if training_args.gradient_checkpointing else True, |
| 279 | device_map=get_kbit_device_map() if quantization_config is not None else None, |
| 280 | quantization_config=quantization_config, |
| 281 | ) |
| 282 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs) |
| 283 | |
| 284 | # moe模型,需要考虑负载均衡的loss |
| 285 | if 'output_router_logits' in model.config.to_dict(): |
| 286 | logger.info('set output_router_logits as True') |
| 287 | model.config.output_router_logits = True |
| 288 | # QLoRA: casts all the non int8 modules to full precision (fp32) for stability |
| 289 | if args.train_mode == 'qlora' and args.task_type in ['pretrain', 'sft']: |
| 290 | model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing) |
| 291 | # LoRA: Enables the gradients for the input embeddings |
| 292 | if args.train_mode == 'lora' and args.task_type in ['pretrain', 'sft']: |
| 293 | # For backward compatibility |
| 294 | if hasattr(model, "enable_input_require_grads"): |
| 295 | model.enable_input_require_grads() |
| 296 | else: |
| 297 | def make_inputs_require_grad(module, input, output): |
| 298 | output.requires_grad_(True) |
| 299 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) |
| 300 | |
| 301 | # init peft_config |
| 302 | if args.train_mode == 'full': |
| 303 | peft_config = None |
| 304 | else: |
| 305 | # 找到所有需要插入adapter的全连接层 |
| 306 | target_modules = find_all_linear_names(model, args.train_mode) |
| 307 | peft_config = LoraConfig( |
| 308 | r=args.lora_rank, |
no test coverage detected