MCPcopy
hub / github.com/mudler/LocalAI / _do_training

Method _do_training

backend/python/trl/backend.py:248–616  ·  view source on GitHub ↗
(self, request, job)

Source from the content-addressed store, hash-verified

246 job.progress_queue.put(None)
247
248 def _do_training(self, request, job):
249 import torch
250 from transformers import AutoModelForCausalLM, AutoTokenizer
251 from datasets import load_dataset, Dataset
252
253 extra = dict(request.extra_options)
254 training_method = request.training_method or "sft"
255 training_type = request.training_type or "lora"
256
257 # Send loading status
258 job.progress_queue.put(backend_pb2.FineTuneProgressUpdate(
259 job_id=job.job_id, status="loading_model", message=f"Loading model {request.model}",
260 ))
261
262 # Determine device and dtype
263 device_map = "auto" if torch.cuda.is_available() else "cpu"
264 dtype = torch.float32 if not torch.cuda.is_available() else torch.bfloat16
265
266 # HuggingFace token for gated repos (from extra_options or HF_TOKEN env)
267 hf_token = extra.get("hf_token") or os.environ.get("HF_TOKEN")
268
269 # Load model
270 model_kwargs = {"device_map": device_map, "torch_dtype": dtype}
271 if hf_token:
272 model_kwargs["token"] = hf_token
273 if extra.get("trust_remote_code", "false").lower() == "true":
274 model_kwargs["trust_remote_code"] = True
275 if extra.get("load_in_4bit", "false").lower() == "true" and torch.cuda.is_available():
276 from transformers import BitsAndBytesConfig
277 model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_4bit=True)
278
279 model = AutoModelForCausalLM.from_pretrained(request.model, **model_kwargs)
280 tokenizer = AutoTokenizer.from_pretrained(request.model, token=hf_token)
281 if tokenizer.pad_token is None:
282 tokenizer.pad_token = tokenizer.eos_token
283
284 job.model = model
285 job.tokenizer = tokenizer
286
287 # Apply LoRA if requested
288 if training_type == "lora":
289 from peft import LoraConfig, get_peft_model
290 lora_r = request.adapter_rank if request.adapter_rank > 0 else 16
291 lora_alpha = request.adapter_alpha if request.adapter_alpha > 0 else 16
292 lora_dropout = request.adapter_dropout if request.adapter_dropout > 0 else 0.0
293
294 target_modules = list(request.target_modules) if request.target_modules else None
295 peft_config = LoraConfig(
296 r=lora_r,
297 lora_alpha=lora_alpha,
298 lora_dropout=lora_dropout,
299 target_modules=target_modules or "all-linear",
300 bias="none",
301 task_type="CAUSAL_LM",
302 )
303 model = get_peft_model(model, peft_config)
304
305 # Load dataset

Callers 2

_run_training Method · 0.95

Calls 5

get_callback Method · 0.95
build_reward_functions Function · 0.90
ProgressCallback Class · 0.70
put Method · 0.45
get Method · 0.45