Implements fully synchronous on-policy training. See `tinker_cookbook.rl.train.do_sync_training` for the original flow. The Agent-lightning adaptation diverges in a few places: * A LiteLLM proxy is restarted every batch so refreshed player checkpoints are immediately visible to rollout workers.
(
*,
start_batch: int,
end_batch: int,
num_batches: int,
cfg: Config,
training_client: tinker.TrainingClient,
service_client: tinker.ServiceClient,
evaluators: list[AGLTestSetEvaluator[Any]],
dataset: AGLDataset[Any],
ml_logger: ml_log.Logger,
tokenizer: Tokenizer,
store: LightningStore,
adapter: TraceToTripletBase,
llm_proxy: LLMProxy,
)
| 109 | |
| 110 | @scope |
| 111 | async def do_sync_training( |
| 112 | *, |
| 113 | start_batch: int, |
| 114 | end_batch: int, |
| 115 | num_batches: int, |
| 116 | cfg: Config, |
| 117 | training_client: tinker.TrainingClient, |
| 118 | service_client: tinker.ServiceClient, |
| 119 | evaluators: list[AGLTestSetEvaluator[Any]], |
| 120 | dataset: AGLDataset[Any], |
| 121 | ml_logger: ml_log.Logger, |
| 122 | tokenizer: Tokenizer, |
| 123 | store: LightningStore, |
| 124 | adapter: TraceToTripletBase, |
| 125 | llm_proxy: LLMProxy, |
| 126 | ): |
| 127 | """Implements fully synchronous on-policy training. |
| 128 | |
| 129 | See `tinker_cookbook.rl.train.do_sync_training` for the original flow. The |
| 130 | Agent-lightning adaptation diverges in a few places: |
| 131 | |
| 132 | * A LiteLLM proxy is restarted every batch so refreshed player checkpoints |
| 133 | are immediately visible to rollout workers. |
| 134 | * Trajectories are gathered via `do_group_of_group_rollouts`, which in turn |
| 135 | dequeues tasks from the Agent-lightning store and rebuilds transitions from |
| 136 | trace triplets. |
| 137 | * Evaluation hooks call `AGLTestSetEvaluator` so validation samples reuse the |
| 138 | same CrewAI-based agent rather than invoking a raw token completer. |
| 139 | """ |
| 140 | |
| 141 | # Initial sampling client |
| 142 | logger.info(f"Creating sampling client with training client {training_client} and start batch {start_batch}") |
| 143 | sampling_client, _ = await save_checkpoint_and_get_sampling_client( |
| 144 | training_client, start_batch, cfg.log_path, cfg.save_every |
| 145 | ) |
| 146 | logger.info(f"Creating renderer with name {cfg.renderer_name}") |
| 147 | renderer = get_renderer(cfg.renderer_name, tokenizer) |
| 148 | |
| 149 | tinker_llm = TinkerLLM( |
| 150 | model_name=cfg.model_name, |
| 151 | sampling_client=sampling_client, |
| 152 | renderer=renderer, |
| 153 | tokenizer=tokenizer, |
| 154 | max_tokens=cfg.max_tokens, |
| 155 | temperature=cfg.train_temperature, |
| 156 | top_k=cfg.top_k, |
| 157 | top_p=cfg.top_p, |
| 158 | ).rewrite_litellm_custom_providers() |
| 159 | |
| 160 | logger.info(f"Starting training from batch {start_batch} to {end_batch}") |
| 161 | for i_batch in range(start_batch, end_batch): |
| 162 | metrics = { |
| 163 | "progress/batch": i_batch, |
| 164 | "optim/lr": cfg.learning_rate, |
| 165 | "progress/done_frac": (i_batch + 1) / num_batches, |
| 166 | } |
| 167 | logger.info(f"[Batch {i_batch}] Starting training step. Learning rate: {cfg.learning_rate}") |
| 168 | t_start = time.time() |
no test coverage detected