| 642 | return backend_pb2.Result(success=True, message="OK") |
| 643 | |
def ListCheckpoints(self, request, context):
    """Report the ``checkpoint-*`` directories found in ``request.output_dir``.

    Args:
        request: Message whose ``output_dir`` field names the directory to scan.
        context: RPC context; not used by this method.

    Returns:
        A ``backend_pb2.ListCheckpointsResponse`` holding one
        ``CheckpointInfo`` per checkpoint directory, ordered by ascending
        step number (ties broken by directory name). The list is empty when
        ``output_dir`` is not a readable directory.

    Metadata is best effort: ``step`` falls back to 0 when the directory name
    carries no integer step, and ``loss``/``epoch`` fall back to 0.0 when
    ``trainer_state.json`` is missing, unreadable, or does not contain them.
    """

    def _to_float(value):
        # Coerce a JSON value to float. Anything non-numeric (null, a string,
        # a nested object) becomes 0.0 so one malformed state file cannot
        # fail the whole RPC.
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return 0.0

    output_dir = request.output_dir
    if not os.path.isdir(output_dir):
        return backend_pb2.ListCheckpointsResponse(checkpoints=[])

    try:
        entries = os.listdir(output_dir)
    except OSError:
        # The directory vanished or became unreadable after the check above.
        return backend_pb2.ListCheckpointsResponse(checkpoints=[])

    candidates = []
    for entry in entries:
        if not entry.startswith("checkpoint-"):
            continue
        try:
            step = int(entry.split("-")[1])
        except (IndexError, ValueError):
            step = 0
        candidates.append((step, entry))
    # Order numerically by step; a plain name sort would place
    # "checkpoint-1000" ahead of "checkpoint-200".
    candidates.sort()

    checkpoints = []
    for step, entry in candidates:
        ckpt_path = os.path.join(output_dir, entry)
        if not os.path.isdir(ckpt_path):
            continue

        # Try to read trainer_state.json for metadata.
        loss = 0.0
        epoch = 0.0
        state_file = os.path.join(ckpt_path, "trainer_state.json")
        try:
            with open(state_file, encoding="utf-8") as f:
                state = json.load(f)
        except (OSError, ValueError):
            # Missing, unreadable, badly encoded or invalid JSON: keep defaults.
            state = None

        history = state.get("log_history") if isinstance(state, dict) else None
        if isinstance(history, list):
            # Walk the log newest-first. The final record is not guaranteed
            # to carry a "loss" key (it may be an evaluation or summary
            # record), so take each field from the most recent record that
            # actually has it.
            records = [rec for rec in reversed(history) if isinstance(rec, dict)]
            loss = _to_float(
                next((rec["loss"] for rec in records if "loss" in rec), 0.0)
            )
            epoch = _to_float(
                next((rec["epoch"] for rec in records if "epoch" in rec), 0.0)
            )

        try:
            mtime = os.path.getmtime(ckpt_path)
        except OSError:
            # Removed between the isdir check and now; nothing left to report.
            continue
        # Note: this is the directory's last-modification time, in UTC.
        created_at = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(mtime))

        checkpoints.append(backend_pb2.CheckpointInfo(
            path=ckpt_path,
            step=step,
            epoch=epoch,
            loss=loss,
            created_at=created_at,
        ))

    return backend_pb2.ListCheckpointsResponse(checkpoints=checkpoints)
| 690 | |
| 691 | def ExportModel(self, request, context): |
| 692 | export_format = request.export_format or "lora" |