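Run the complete SFT algorithm with multiple iterations. This is the main entry point for running the SFT training pipeline. It sets up the LLM proxy, data adapter, and runs multiple iterations of model training. The function performs these steps for each iteration: 1. Serves the current model via vLLM. 2. Collects rollout data using the model. 3. Converts trace data to training triplets. 4. Trains the model on top-performing examples. 5. Saves the improved model for the next iteration.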
(*, store: LightningStore)
| 330 | |
| 331 | |
async def sft_algorithm(*, store: LightningStore) -> None:
    """Drive the iterative SFT training pipeline from start to finish.

    Builds the pieces shared by every round (the math training dataset, the LLM
    proxy and the trace-to-triplet adapter) and then calls `sft_one_iter` a
    fixed number of times, handing each round the model path returned by the
    previous one. The final model path is printed to the console.

    Each round is expected to:

    1. Serve the current model via vLLM.
    2. Collect rollout data using that model.
    3. Convert the trace data to training triplets.
    4. Train the model on top-performing examples.
    5. Save the improved model for the next round.

    Args:
        store: The LightningStore instance for managing rollouts and trace data.
    """
    # Fixed settings for the SFT run.
    num_rounds = 2
    vllm_serving_port = 12316
    proxy_port = 12358
    triplet_keep_fraction = 0.5

    dataset = load_math_dataset()

    # The starting checkpoint has to be downloaded before launching the script:
    #   hf download unsloth/Qwen3-4B-Instruct-2507 --local-dir models/version_0
    model_path = "models/version_0"

    # LLM proxy used for rollout worker access and trace data collection.
    proxy = LLMProxy(port=proxy_port, store=store)

    # Adapter that turns the trace data recorded by the LLM proxy into a format
    # suitable for SFT.
    adapter = LlmProxyTraceToTriplet()

    # Every round starts from the model produced by the round before it.
    for round_index in range(num_rounds):
        model_path = await sft_one_iter(
            iteration=round_index,
            store=store,
            model_path=model_path,
            train_dataset=dataset,
            llm_proxy=proxy,
            data_adapter=adapter,
            triplet_fraction=triplet_keep_fraction,
            vllm_port=vllm_serving_port,
        )

    console.print(f"[bold red][Algo][/bold red] Final model path: {model_path}")
| 380 | |
| 381 | |
| 382 | if __name__ == "__main__": |
no test coverage detected