| 682 | progress_q = job.progress_queue |
| 683 | |
class QueuedTrainer(Trainer):
    """Trainer that mirrors training progress onto ``progress_q``.

    Every update is a ``backend_pb2.FineTuneProgressUpdate`` tagged with
    ``job.job_id``; ``job`` and ``progress_q`` come from the enclosing scope.
    """

    def log(self_, model_output):
        """Publish a progress update every ``logging_interval`` steps.

        On every call, a pending stop request (``job.stopped``) aborts
        training by raising ``KeyboardInterrupt``; otherwise the call is
        delegated to ``Trainer.log``.

        Raises:
            KeyboardInterrupt: when ``job.stopped`` is set.
        """
        step = self_.step
        interval = self_.logging_interval
        # A zero/None interval would make the modulo raise and kill training;
        # treat it as "progress logging disabled" instead.
        if step > 0 and interval and step % interval == 0:
            try:
                loss = self_.accelerator.reduce(
                    model_output.loss.detach(), reduction="mean"
                ).item()
            except Exception:
                # Best-effort reporting: an unavailable/unreducible loss must
                # not abort training, so report NaN instead.
                loss = float("nan")
            lr_now = self_.optimizer.param_groups[0]["lr"]
            max_steps = self_.max_steps
            # max_steps may be unset (falsy): avoid dividing by it and avoid
            # int(None) when reporting the total.
            pct = (step / max_steps * 100.0) if max_steps else 0.0
            progress_q.put(backend_pb2.FineTuneProgressUpdate(
                job_id=job.job_id,
                current_step=int(step),
                total_steps=int(max_steps or 0),
                current_epoch=float(self_.epoch),
                loss=float(loss),
                learning_rate=float(lr_now),
                progress_percent=float(pct),
                status="training",
            ))
        # Honour stop requests on every call, not only on logging steps, so a
        # stop is not delayed by up to logging_interval - 1 steps. Raising
        # here terminates the loop cleanly.
        if job.stopped:
            raise KeyboardInterrupt("stop requested")
        return super().log(model_output)

    def validate(self_):
        """Announce a validation pass on ``progress_q``, then run it."""
        progress_q.put(backend_pb2.FineTuneProgressUpdate(
            job_id=job.job_id,
            current_step=int(self_.step),
            # Same falsy-max_steps guard as in log(): int(None) would raise.
            total_steps=int(self_.max_steps or 0),
            status="training",
            message=f"Running validation at step {self_.step}",
        ))
        return super().validate()
| 717 | |
| 718 | trainer = QueuedTrainer( |
| 719 | model_id=model_id, |