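One iteration of SFT.

Collects all trace data from the rollouts, then uses the reward to select the top triplets to train on. Performs (1) rollout (data collection), (2) data conversion, (3) SFT training, and (4) model saving.

Args:
    iteration: The iteration number.
    store: The LightningStore instance.
    model_path: The path to the model to train. Must be a local path.
    train_dataset: The dataset to train on.
    llm_proxy: The LLM proxy instance. Sits between the inference endpoint and the rollout runners.
    data_adapter: The data adapter instance. Converts the trace data recorded by the LLM proxy into triplets.
    triplet_fraction: The fraction of highest-reward triplets to keep for SFT.
    vllm_port: The port on which to serve the vLLM chat completion endpoint.

Returns:
    The path to the saved model (next generation).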
(
    *,
    iteration: int,
    store: LightningStore,
    model_path: str,
    train_dataset: Dataset[GsmProblem],
    llm_proxy: LLMProxy,
    data_adapter: TraceToTripletBase,
    triplet_fraction: float,
    vllm_port: int,
) -> str
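The reward-based selection step described above (keep only the top triplets, sized by `triplet_fraction`) can be sketched as follows. This is a minimal illustration under stated assumptions, not the library's implementation: the `Triplet` dataclass and the `select_top_triplets` helper are hypothetical, and the sketch assumes that a higher reward is better, that triplets with no reward are dropped, and that the kept count is rounded up.

```python
import math
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class Triplet:
    """Hypothetical stand-in for one (prompt, response, reward) training example."""
    prompt: str
    response: str
    reward: Optional[float]


def select_top_triplets(triplets: Sequence[Triplet], triplet_fraction: float) -> List[Triplet]:
    """Keep the highest-reward fraction of triplets for SFT (illustrative sketch)."""
    if not 0.0 < triplet_fraction <= 1.0:
        raise ValueError(f"triplet_fraction must be in (0, 1], got {triplet_fraction}")
    # Assumption: triplets without a reward carry no training signal, so drop them.
    scored = [t for t in triplets if t.reward is not None]
    # Round up so that a small positive fraction still keeps at least one triplet.
    k = math.ceil(len(scored) * triplet_fraction)
    # sorted() stays stable with reverse=True, so ties keep their original rollout order.
    ranked = sorted(scored, key=lambda t: t.reward, reverse=True)
    return ranked[:k]


# Usage: three scored triplets, fraction 0.5 -> ceil(1.5) = 2 triplets kept.
data = [
    Triplet("q1", "a1", 0.0),
    Triplet("q2", "a2", 1.0),
    Triplet("q3", "a3", None),
    Triplet("q4", "a4", 0.5),
]
print([t.prompt for t in select_top_triplets(data, 0.5)])  # ['q2', 'q4']
```

Sorting the full list is O(n log n); for very large rollout batches, `heapq.nlargest` would select the top `k` without a full sort.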
```python
async def sft_one_iter(
    *,
    iteration: int,
    store: LightningStore,
    model_path: str,
    train_dataset: Dataset[GsmProblem],
    llm_proxy: LLMProxy,
    data_adapter: TraceToTripletBase,
    triplet_fraction: float,
    vllm_port: int,
) -> str:
    """One iteration of SFT.

    Collects all trace data from the rollouts, then uses the reward to select the top triplets to train on.

    Performs (1) rollout (data collection), (2) data conversion, (3) SFT training, and (4) model saving.

    Args:
        iteration: The iteration number.
        store: The LightningStore instance.
        model_path: The path to the model to train. Must be a local path.
        train_dataset: The dataset to train on.
        llm_proxy: The LLM proxy instance. Sits between the inference endpoint and the rollout runners.
        data_adapter: The data adapter instance. Converts the trace data recorded by the LLM proxy into triplets.
        triplet_fraction: The fraction of highest-reward triplets to keep for SFT.
        vllm_port: The port on which to serve the vLLM chat completion endpoint.

    Returns:
        The path to the saved model (next generation).
    """

    console.print(f"\n[bold red][Algo][/bold red] Starting iteration {iteration}")

    # 1. Rollout to get trace data
    if not os.path.exists(model_path):
        raise ValueError(f"Model path {model_path} does not exist.")

    # First launch the vLLM server
    with vllm_server(model_path, vllm_port) as server_address:
        # Point the LLM proxy's model list at the freshly launched vLLM backend
        model_list: List[ModelConfig] = [
            {
                "model_name": "Qwen3-4B-Instruct",
                "litellm_params": {
                    "model": f"hosted_vllm/{model_path}",
                    "api_base": server_address,
                },
            }
        ]
        console.print(f"[bold red][Algo][/bold red] Updating model list and restarting LLM proxy: {model_list}")
        llm_proxy.update_model_list(model_list)
        # Restart the LLM proxy so it picks up the updated backend model list
        # (this also starts the proxy if it has never been started)
        await llm_proxy.restart()

        # Register the LLM proxy address in the store as a resource
        resources_update = await store.add_resources(
            {
```
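Because each call takes a local `model_path` and returns the path to the next-generation model, iterations can be chained so that each one fine-tunes the checkpoint produced by the previous one. The driver below is a hypothetical sketch: `run_sft_loop` and the stand-in `fake_step` are not part of the source. To drive the real function, the step could be a closure such as `lambda i, path: sft_one_iter(iteration=i, model_path=path, store=store, train_dataset=train_dataset, llm_proxy=llm_proxy, data_adapter=data_adapter, triplet_fraction=triplet_fraction, vllm_port=vllm_port)`, which returns an awaitable as the driver expects.

```python
import asyncio
from typing import Awaitable, Callable, List

# A step takes (iteration, model_path) and returns the path of the next-generation model.
StepFn = Callable[[int, str], Awaitable[str]]


async def run_sft_loop(step: StepFn, initial_model_path: str, num_iterations: int) -> List[str]:
    """Hypothetical driver: feed each iteration's output checkpoint into the next iteration."""
    model_paths = [initial_model_path]
    for iteration in range(num_iterations):
        # Iterations run sequentially: each one needs the previous checkpoint.
        next_path = await step(iteration, model_paths[-1])
        model_paths.append(next_path)
    return model_paths


async def fake_step(iteration: int, model_path: str) -> str:
    """Stand-in for sft_one_iter that derives a new checkpoint name from its input."""
    return f"{model_path}-sft{iteration}"


if __name__ == "__main__":
    print(asyncio.run(run_sft_loop(fake_step, "models/base", 2)))
    # ['models/base', 'models/base-sft0', 'models/base-sft0-sft1']
```

Returning every intermediate path (rather than only the last one) keeps each generation's checkpoint addressable, for example to evaluate or roll back to an earlier iteration.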