MCPcopy
hub / github.com/yangjianxin1/Firefly / load_model

Function load_model

train.py:251–337  ·  view source on GitHub ↗

加载模型 (Load the model)

(args, training_args)

Source from the content-addressed store, hash-verified

249
250
251def load_model(args, training_args):
252 """
253 加载模型
254 """
255 assert training_args.bf16 or training_args.fp16, 'bf16 or fp16 should be True'
256 logger.info(f'Loading model from base model: {args.model_name_or_path}')
257 logger.info(f'Train model with {args.train_mode}')
258
259 # init model kwargs
260 # todo add flash attention
261 # attn_implementation = None
262 torch_dtype = torch.float16 if training_args.fp16 else torch.bfloat16
263 if args.train_mode == 'qlora':
264 quantization_config = BitsAndBytesConfig(
265 load_in_4bit=True,
266 bnb_4bit_compute_dtype=torch.float16 if training_args.fp16 else torch.bfloat16,
267 bnb_4bit_use_double_quant=True,
268 bnb_4bit_quant_type="nf4",
269 llm_int8_threshold=6.0,
270 llm_int8_has_fp16_weight=False,
271 )
272 else:
273 quantization_config = None
274 model_kwargs = dict(
275 trust_remote_code=True,
276 # attn_implementation=attn_implementation,
277 torch_dtype=torch_dtype,
278 use_cache=False if training_args.gradient_checkpointing else True,
279 device_map=get_kbit_device_map() if quantization_config is not None else None,
280 quantization_config=quantization_config,
281 )
282 model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, **model_kwargs)
283
284 # moe模型,需要考虑负载均衡的loss
285 if 'output_router_logits' in model.config.to_dict():
286 logger.info('set output_router_logits as True')
287 model.config.output_router_logits = True
288 # QLoRA: casts all the non int8 modules to full precision (fp32) for stability
289 if args.train_mode == 'qlora' and args.task_type in ['pretrain', 'sft']:
290 model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=training_args.gradient_checkpointing)
291 # LoRA: Enables the gradients for the input embeddings
292 if args.train_mode == 'lora' and args.task_type in ['pretrain', 'sft']:
293 # For backward compatibility
294 if hasattr(model, "enable_input_require_grads"):
295 model.enable_input_require_grads()
296 else:
297 def make_inputs_require_grad(module, input, output):
298 output.requires_grad_(True)
299 model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
300
301 # init peft_config
302 if args.train_mode == 'full':
303 peft_config = None
304 else:
305 # 找到所有需要插入adapter的全连接层
306 target_modules = find_all_linear_names(model, args.train_mode)
307 peft_config = LoraConfig(
308 r=args.lora_rank,

Callers 1

init_components — Function · 0.85

Calls 1

find_all_linear_names — Function · 0.85

Tested by

no test coverage detected