MCPcopy
hub / github.com/mudler/LocalAI / get_callback

Method get_callback

backend/python/trl/backend.py:39–159  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

37 self.total_epochs = total_epochs
38
39 def get_callback(self):
40 from transformers import TrainerCallback
41
42 parent = self
43
44 class _Callback(TrainerCallback):
def __init__(self):
    """Create the callback with no training start time recorded yet."""
    # Set by on_train_begin; stays None until training has started.
    self._train_start_time = None
47
def on_train_begin(self, args, state, control, **kwargs):
    """Record the wall-clock time at which training began.

    The stored timestamp is read back by on_log to estimate the
    remaining training time. None of the arguments are used.
    """
    started_at = time.time()
    self._train_start_time = started_at
50
def on_log(self, args, state, control, logs=None, **kwargs):
    """Publish a training progress update built from the latest log entry.

    Does nothing when ``logs`` is None. Otherwise builds a
    ``FineTuneProgressUpdate`` with status "training" from ``state`` and
    ``logs`` and puts it on ``parent.progress_queue``.

    Args:
        args: Unused.
        state: Provides ``global_step`` and ``max_steps``.
        control: Unused.
        logs: Mapping of metric name to value, or None.
        **kwargs: Unused.
    """
    if logs is None:
        return

    step = state.global_step
    # A non-positive max_steps is reported as 0 total steps, which also
    # disables the percentage and ETA calculations below.
    total_steps = state.max_steps if state.max_steps > 0 else 0
    progress = (step / total_steps * 100) if total_steps > 0 else 0

    eta = 0.0
    if step > 0 and total_steps > 0 and self._train_start_time:
        # Linear extrapolation: average seconds per completed step so far,
        # multiplied by the number of steps still to run.
        elapsed = time.time() - self._train_start_time
        eta = (total_steps - step) * (elapsed / step)

    # These keys have dedicated fields on the update message; every other
    # numeric log value is forwarded through extra_metrics.
    reported_directly = ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss')
    extra_metrics = {
        name: float(value)
        for name, value in logs.items()
        if isinstance(value, (int, float)) and name not in reported_directly
    }

    update = backend_pb2.FineTuneProgressUpdate(
        job_id=parent.job_id,
        current_step=step,
        total_steps=total_steps,
        current_epoch=float(logs.get('epoch', 0)),
        total_epochs=float(parent.total_epochs),
        loss=float(logs.get('loss', 0)),
        learning_rate=float(logs.get('learning_rate', 0)),
        grad_norm=float(logs.get('grad_norm', 0)),
        eval_loss=float(logs.get('eval_loss', 0)),
        eta_seconds=float(eta),
        progress_percent=float(progress),
        status="training",
        extra_metrics=extra_metrics,
    )
    parent.progress_queue.put(update)
84
85 def on_prediction_step(self, args, state, control, **kwargs):
86 """Send periodic updates during evaluation so the UI doesn't freeze."""
87 if not hasattr(self, '_eval_update_counter'):
88 self._eval_update_counter = 0
89 self._eval_update_counter += 1
90 # Throttle: send an update every 10 prediction steps
91 if self._eval_update_counter % 10 != 0:
92 return
93 total_steps = state.max_steps if state.max_steps > 0 else 0
94 progress = (state.global_step / total_steps * 100) if total_steps > 0 else 0
95 update = backend_pb2.FineTuneProgressUpdate(
96 job_id=parent.job_id,

Callers 1

_do_training · Method · 0.95

Calls 1

_Callback · Class · 0.85

Tested by

no test coverage detected