HuggingFace TrainerCallback that pushes progress updates to a queue.
| 29 | |
| 30 | |
| 31 | class ProgressCallback: |
| 32 | """HuggingFace TrainerCallback that pushes progress updates to a queue.""" |
| 33 | |
| 34 | def __init__(self, job_id, progress_queue, total_epochs): |
| 35 | self.job_id = job_id |
| 36 | self.progress_queue = progress_queue |
| 37 | self.total_epochs = total_epochs |
| 38 | |
| 39 | def get_callback(self): |
| 40 | from transformers import TrainerCallback |
| 41 | |
| 42 | parent = self |
| 43 | |
| 44 | class _Callback(TrainerCallback): |
def __init__(self):
    """Declare the per-run state used by the event handlers."""
    # Chain to the base class so any initialisation it performs is kept.
    super().__init__()
    # Wall-clock start of training; stays None until on_train_begin
    # records it, and on_log reports an ETA of 0.0 while it is unset.
    self._train_start_time = None
    # on_prediction_step creates this counter lazily via hasattr;
    # declaring it here (at the same initial value, 0) keeps every
    # instance attribute defined from construction.
    self._eval_update_counter = 0
| 47 | |
def on_train_begin(self, args, state, control, **kwargs):
    """Record the wall-clock time at which training started.

    The timestamp is read back by on_log to estimate the remaining time.
    The ``args``, ``state`` and ``control`` arguments are not used.
    """
    started_at = time.time()
    self._train_start_time = started_at
| 50 | |
def on_log(self, args, state, control, logs=None, **kwargs):
    """Publish a progress snapshot for the current log event.

    Does nothing when ``logs`` is None. Otherwise a
    ``FineTuneProgressUpdate`` is built from the trainer state and the
    logged values, then put on the parent's progress queue.

    Args:
        args: Unused.
        state: Provides ``global_step`` and ``max_steps``.
        control: Unused.
        logs: Mapping of logged metric names to values, or None.
    """
    if logs is None:
        return

    step = state.global_step
    # A non-positive max_steps is reported as 0 total steps.
    total_steps = max(state.max_steps, 0)
    has_total = total_steps > 0
    progress = step / total_steps * 100 if has_total else 0

    # ETA = steps left * average seconds per completed step. It stays 0.0
    # until at least one step is done and a start time has been recorded.
    eta = 0.0
    if step > 0 and has_total and self._train_start_time:
        elapsed = time.time() - self._train_start_time
        eta = (total_steps - step) * (elapsed / step)

    # Numeric log entries that have no dedicated message field are
    # forwarded as extra metrics.
    dedicated = ('loss', 'learning_rate', 'epoch', 'grad_norm', 'eval_loss')
    extra_metrics = {
        name: float(value)
        for name, value in logs.items()
        if isinstance(value, (int, float)) and name not in dedicated
    }

    def logged(name):
        # Missing entries are reported as 0.0.
        return float(logs.get(name, 0))

    parent.progress_queue.put(
        backend_pb2.FineTuneProgressUpdate(
            job_id=parent.job_id,
            current_step=step,
            total_steps=total_steps,
            current_epoch=logged('epoch'),
            total_epochs=float(parent.total_epochs),
            loss=logged('loss'),
            learning_rate=logged('learning_rate'),
            grad_norm=logged('grad_norm'),
            eval_loss=logged('eval_loss'),
            eta_seconds=float(eta),
            progress_percent=float(progress),
            status="training",
            extra_metrics=extra_metrics,
        )
    )
| 84 | |
| 85 | def on_prediction_step(self, args, state, control, **kwargs): |
| 86 | """Send periodic updates during evaluation so the UI doesn't freeze.""" |
| 87 | if not hasattr(self, '_eval_update_counter'): |
| 88 | self._eval_update_counter = 0 |