hub / github.com/microsoft/agent-lightning / sft_algorithm

Function sft_algorithm

examples/unsloth/sft_algorithm.py:332–379

Run the complete SFT algorithm with multiple iterations. This is the main entry point for the SFT training pipeline: it sets up the LLM proxy and the data adapter, then runs multiple iterations of model training.

(*, store: LightningStore)
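The leading `*` in the signature makes `store` keyword-only, and because the function is a coroutine it has to be driven by an event loop. The following is a minimal sketch of that calling convention only. `FakeStore` and the stub body are hypothetical stand-ins, not part of agent-lightning.

```python
import asyncio


class FakeStore:
    """Hypothetical stand-in for LightningStore, used only in this sketch."""


async def sft_algorithm(*, store: FakeStore) -> None:
    # Stub body: the real function runs the training iterations.
    print(f"running with {type(store).__name__}")


def try_positional() -> str:
    """Show that passing the store positionally is rejected at call time."""
    try:
        sft_algorithm(FakeStore())  # type: ignore[misc]
    except TypeError as exc:
        return f"TypeError: {exc}"
    return "no error"


# Keyword form works; the coroutine must be awaited or run by an event loop.
asyncio.run(sft_algorithm(store=FakeStore()))
print(try_positional())
```

Keyword-only parameters keep call sites explicit, which helps when a function is later extended with more arguments.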


async def sft_algorithm(*, store: LightningStore) -> None:
    """Run the complete SFT algorithm with multiple iterations.

    This is the main entry point for running the SFT training pipeline. It sets up
    the LLM proxy, data adapter, and runs multiple iterations of model training.

    The function performs these steps for each iteration:
    1. Serves the current model via vLLM
    2. Collects rollout data using the model
    3. Converts trace data to training triplets
    4. Trains the model on top-performing examples
    5. Saves the improved model for the next iteration

    Args:
        store: The LightningStore instance for managing rollouts and trace data.
    """
    train_dataset = load_math_dataset()

    # Constants for the SFT algorithm
    MAX_ITERATIONS = 2
    VLLM_PORT = 12316
    LLM_PROXY_PORT = 12358
    TRAIN_TRIPLET_FRACTION = 0.5

    # Download the model before starting the script:
    # hf download unsloth/Qwen3-4B-Instruct-2507 --local-dir models/version_0
    model_path = "models/version_0"

    # Create the LLM proxy for rollout worker access and trace data collection
    llm_proxy = LLMProxy(port=LLM_PROXY_PORT, store=store)

    # This data adapter util is used to convert the trace data recorded by LLM proxy
    # into a format suitable for SFT
    data_adapter = LlmProxyTraceToTriplet()

    for iteration in range(MAX_ITERATIONS):
        model_path = await sft_one_iter(
            iteration=iteration,
            store=store,
            model_path=model_path,
            train_dataset=train_dataset,
            llm_proxy=llm_proxy,
            data_adapter=data_adapter,
            triplet_fraction=TRAIN_TRIPLET_FRACTION,
            vllm_port=VLLM_PORT,
        )

    console.print(f"[bold red][Algo][/bold red] Final model path: {model_path}")


if __name__ == "__main__":
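The loop in the listing threads `model_path` through the iterations: each call to `sft_one_iter` returns the path that the next call starts from. The sketch below illustrates that pattern along with one plausible reading of "top-performing examples". It is a sketch under stated assumptions, not the repository's implementation: `fake_sft_one_iter`, `select_top_fraction`, the `(prompt, response, reward)` tuple shape, and the `models/version_N` naming for later iterations are all hypothetical.

```python
import asyncio
from typing import List, Tuple

# Hypothetical (prompt, response, reward) records; the real triplet type may differ.
Triplet = Tuple[str, str, float]


def select_top_fraction(triplets: List[Triplet], fraction: float) -> List[Triplet]:
    """Keep the highest-reward share of triplets (at least one if any exist)."""
    if not triplets:
        return []
    ranked = sorted(triplets, key=lambda t: t[2], reverse=True)
    keep = max(1, int(len(ranked) * fraction))
    return ranked[:keep]


async def fake_sft_one_iter(*, iteration: int, model_path: str, triplet_fraction: float) -> str:
    """Stub iteration: pretend to collect rollouts, filter them, and train."""
    collected: List[Triplet] = [
        ("2+2", "4", 1.0),
        ("3*3", "6", 0.0),
        ("10/2", "5", 1.0),
        ("7-1", "5", 0.0),
    ]
    training_set = select_top_fraction(collected, triplet_fraction)
    print(f"iter {iteration}: {len(training_set)}/{len(collected)} triplets, base {model_path}")
    # Assumed naming scheme, chosen only to mirror "models/version_0".
    return f"models/version_{iteration + 1}"


async def run_loop(max_iterations: int = 2) -> str:
    model_path = "models/version_0"
    for iteration in range(max_iterations):
        # Each iteration's output path becomes the next iteration's input.
        model_path = await fake_sft_one_iter(
            iteration=iteration, model_path=model_path, triplet_fraction=0.5
        )
    return model_path


final_path = asyncio.run(run_loop())
print(final_path)
```

Returning the new path from each iteration, rather than mutating shared state, keeps the loop body free of bookkeeping and makes the final model location available in a single variable.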

Callers 1

sft_algorithm.py · File · 0.85

Calls 4

load_math_dataset · Function · 0.90
LLMProxy · Class · 0.90
sft_one_iter · Function · 0.85

Tested by

no test coverage detected