hub / github.com/microsoft/agent-lightning / sft_one_iter

Function sft_one_iter

examples/unsloth/sft_algorithm.py:135–329

One iteration of SFT. It collects all trace data from the rollouts, then uses the reward to select the top triplets to train on. It performs (1) rollout (data collection), (2) data conversion, (3) SFT training, and (4) model saving, and returns the path to the saved model (the next generation).
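The reward-based selection described above can be sketched as follows. This is a minimal illustration, not the repository's code: the `Triplet` dataclass and `select_top_triplets` helper are hypothetical names, and the rules used here (drop unscored triplets, round the kept count up, keep at least one) are assumptions about how a `triplet_fraction` might be applied.

```python
import math
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Triplet:
    """Hypothetical stand-in for one (prompt, response, reward) record."""

    prompt: str
    response: str
    reward: Optional[float]


def select_top_triplets(triplets: List[Triplet], fraction: float) -> List[Triplet]:
    """Keep the highest-reward `fraction` of triplets (assumed selection rule)."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    # Triplets without a reward cannot be ranked, so this sketch drops them.
    scored = [t for t in triplets if t.reward is not None]
    if not scored:
        return []
    # Sort best first; sorted() is stable, so ties keep their input order.
    ranked = sorted(scored, key=lambda t: t.reward, reverse=True)
    # Round up and keep at least one triplet so a small fraction is not empty.
    keep = max(1, math.ceil(len(ranked) * fraction))
    return ranked[:keep]
```

With four scored triplets and a fraction of 0.5, this sketch keeps the two with the highest rewards.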

(
    *,
    iteration: int,
    store: LightningStore,
    model_path: str,
    train_dataset: Dataset[GsmProblem],
    llm_proxy: LLMProxy,
    data_adapter: TraceToTripletBase,
    triplet_fraction: float,
    vllm_port: int,
)
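Because the function takes a model path and returns the path of the next-generation model, successive iterations can be chained. The sketch below shows that pattern only: `fake_sft_one_iter` is a hypothetical stand-in that fabricates a path string instead of training, and `run_iterations` is an illustrative driver, not the repository's `sft_algorithm`.

```python
import asyncio


async def fake_sft_one_iter(*, iteration: int, model_path: str) -> str:
    """Hypothetical stand-in: returns a new path string instead of training."""
    await asyncio.sleep(0)  # yield to the event loop, as a real coroutine would
    return f"{model_path}_iter{iteration}"


async def run_iterations(initial_model_path: str, num_iterations: int) -> str:
    """Chain iterations so each one starts from the previous iteration's output."""
    model_path = initial_model_path
    for iteration in range(num_iterations):
        model_path = await fake_sft_one_iter(iteration=iteration, model_path=model_path)
    return model_path


if __name__ == "__main__":
    print(asyncio.run(run_iterations("models/base", 3)))
```

The keyword-only call style mirrors the `*` in the signature above, which forces callers to name every argument.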

Source from the content-addressed store, hash-verified

async def sft_one_iter(
    *,
    iteration: int,
    store: LightningStore,
    model_path: str,
    train_dataset: Dataset[GsmProblem],
    llm_proxy: LLMProxy,
    data_adapter: TraceToTripletBase,
    triplet_fraction: float,
    vllm_port: int,
) -> str:
    """One iteration of SFT.

    The idea is to get all trace data from the rollouts, and then use the reward to select the top triplets to train on.

    Performs (1) rollout - data collection, (2) data conversion, (3) SFT training, and (4) model saving.

    Args:
        iteration: The iteration number.
        store: The LightningStore instance.
        model_path: The path to the model to train. Must be a local path.
        train_dataset: The dataset to train on.
        llm_proxy: The LLM proxy instance. Used to shield between the inference endpoint and the rollout runners.
        data_adapter: The data adapter instance. This is used to convert the trace data recorded by LLM proxy.
        triplet_fraction: The fraction of triplets to use for SFT.
        vllm_port: The port to serve vLLM chat completion endpoint.

    Returns:
        The path to the saved model (next generation).
    """

    console.print(f"\n[bold red][Algo][/bold red] Starting iteration {iteration}")

    # 1. Rollout to get trace data
    if not os.path.exists(model_path):
        raise ValueError(f"Model path {model_path} does not exist.")

    # First launch the vLLM server
    with vllm_server(model_path, vllm_port) as server_address:
        # Update the model list of the LLM proxy and start it
        model_list: List[ModelConfig] = [
            {
                "model_name": "Qwen3-4B-Instruct",
                "litellm_params": {
                    "model": f"hosted_vllm/{model_path}",
                    "api_base": server_address,
                },
            }
        ]
        console.print(f"[bold red][Algo][/bold red] Updating model list and restarting LLM proxy: {model_list}")
        llm_proxy.update_model_list(model_list)
        # Restart the LLM proxy after backend model list update
        # If LLM proxy has never been started, it will be started
        await llm_proxy.restart()

        # Put the LLM proxy address into the store as an address
        resources_update = await store.add_resources(
            {
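The excerpt above validates the local model path and then builds a one-entry model list that points the proxy at the freshly started vLLM server. As a rough sketch, those two steps could be factored into a helper. `build_model_list` is a hypothetical name, plain dictionaries stand in for the `ModelConfig` type, and any server address passed in (for example `http://localhost:8000/v1`) is an assumed value rather than one taken from the source.

```python
import os
from typing import Any, Dict, List


def build_model_list(
    model_path: str,
    server_address: str,
    model_name: str = "Qwen3-4B-Instruct",
) -> List[Dict[str, Any]]:
    """Build a one-entry model list routing `model_name` to a local vLLM server."""
    # Mirror the guard in the excerpt: the model must exist on local disk.
    if not os.path.exists(model_path):
        raise ValueError(f"Model path {model_path} does not exist.")
    return [
        {
            # Name that clients of the proxy request.
            "model_name": model_name,
            "litellm_params": {
                # The "hosted_vllm/" prefix and key names are copied from the excerpt.
                "model": f"hosted_vllm/{model_path}",
                "api_base": server_address,
            },
        }
    ]
```

Keeping the client-facing `model_name` fixed while the backing `model` path changes means rollout code can keep requesting the same name across iterations.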

Callers 2

run · Method · 0.90
sft_algorithm · Function · 0.85

Calls 11

update_model_list · Method · 0.80
restart · Method · 0.80
as_resource · Method · 0.80
vllm_server · Function · 0.70
add_resources · Method · 0.45
enqueue_rollout · Method · 0.45
wait_for_rollouts · Method · 0.45
query_spans · Method · 0.45
adapt · Method · 0.45
get · Method · 0.45
start · Method · 0.45

Tested by

no test coverage detected