hub / github.com/microsoft/agent-lightning / do_sync_training

Function do_sync_training

examples/tinker/agl_tinker/train.py:111–224

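Implements fully synchronous on-policy training. See `tinker_cookbook.rl.train.do_sync_training` for the original flow. The Agent-lightning adaptation diverges in a few places:

* A LiteLLM proxy is restarted every batch so refreshed player checkpoints are immediately visible to rollout workers.
* Trajectories are gathered via `do_group_of_group_rollouts`, which in turn dequeues tasks from the Agent-lightning store and rebuilds transitions from trace triplets.
* Evaluation hooks call `AGLTestSetEvaluator` so validation samples reuse the same CrewAI-based agent rather than invoking a raw token completer.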

(
    *,
    start_batch: int,
    end_batch: int,
    num_batches: int,
    cfg: Config,
    training_client: tinker.TrainingClient,
    service_client: tinker.ServiceClient,
    evaluators: list[AGLTestSetEvaluator[Any]],
    dataset: AGLDataset[Any],
    ml_logger: ml_log.Logger,
    tokenizer: Tokenizer,
    store: LightningStore,
    adapter: TraceToTripletBase,
    llm_proxy: LLMProxy,
)
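The bare `*` at the head of the signature makes every parameter keyword-only, so a caller has to name each argument. A minimal sketch of that behavior follows; `do_sync_training_stub` is a hypothetical, trimmed-down stand-in that keeps only the three batch counters and is not part of Agent-lightning.

```python
def do_sync_training_stub(*, start_batch: int, end_batch: int, num_batches: int) -> list[int]:
    """Hypothetical stand-in: returns the batch indices the loop would visit."""
    # num_batches is accepted only to mirror the real signature; it is unused here.
    return list(range(start_batch, end_batch))


# Keyword arguments work, e.g. when resuming a run at batch 3.
batches = do_sync_training_stub(start_batch=3, end_batch=10, num_batches=10)

# Positional arguments are rejected by Python itself because of the bare `*`.
try:
    do_sync_training_stub(3, 10, 10)
    positional_ok = True
except TypeError:
    positional_ok = False
```

With thirteen parameters, several of them plain integers, naming each argument at the call site guards against passing `start_batch` and `end_batch` in the wrong order.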

Source from the content-addressed store, hash-verified

@scope
async def do_sync_training(
    *,
    start_batch: int,
    end_batch: int,
    num_batches: int,
    cfg: Config,
    training_client: tinker.TrainingClient,
    service_client: tinker.ServiceClient,
    evaluators: list[AGLTestSetEvaluator[Any]],
    dataset: AGLDataset[Any],
    ml_logger: ml_log.Logger,
    tokenizer: Tokenizer,
    store: LightningStore,
    adapter: TraceToTripletBase,
    llm_proxy: LLMProxy,
):
    """Implements fully synchronous on-policy training.

    See `tinker_cookbook.rl.train.do_sync_training` for the original flow. The
    Agent-lightning adaptation diverges in a few places:

    * A LiteLLM proxy is restarted every batch so refreshed player checkpoints
      are immediately visible to rollout workers.
    * Trajectories are gathered via `do_group_of_group_rollouts`, which in turn
      dequeues tasks from the Agent-lightning store and rebuilds transitions from
      trace triplets.
    * Evaluation hooks call `AGLTestSetEvaluator` so validation samples reuse the
      same CrewAI-based agent rather than invoking a raw token completer.
    """

    # Initial sampling client
    logger.info(f"Creating sampling client with training client {training_client} and start batch {start_batch}")
    sampling_client, _ = await save_checkpoint_and_get_sampling_client(
        training_client, start_batch, cfg.log_path, cfg.save_every
    )
    logger.info(f"Creating renderer with name {cfg.renderer_name}")
    renderer = get_renderer(cfg.renderer_name, tokenizer)

    tinker_llm = TinkerLLM(
        model_name=cfg.model_name,
        sampling_client=sampling_client,
        renderer=renderer,
        tokenizer=tokenizer,
        max_tokens=cfg.max_tokens,
        temperature=cfg.train_temperature,
        top_k=cfg.top_k,
        top_p=cfg.top_p,
    ).rewrite_litellm_custom_providers()

    logger.info(f"Starting training from batch {start_batch} to {end_batch}")
    for i_batch in range(start_batch, end_batch):
        metrics = {
            "progress/batch": i_batch,
            "optim/lr": cfg.learning_rate,
            "progress/done_frac": (i_batch + 1) / num_batches,
        }
        logger.info(f"[Batch {i_batch}] Starting training step. Learning rate: {cfg.learning_rate}")
        t_start = time.time()
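The excerpt ends just as the per-batch timer starts. To illustrate the control flow the docstring describes (restart the proxy, gather rollouts from the current policy, then train before moving on), here is a self-contained toy sketch. Everything in it is a hypothetical stand-in: `FakeProxy`, `FakeTrainer`, `collect_rollouts`, and the `reward/mean` and `time/total` metric keys are assumptions for illustration, not Agent-lightning or Tinker APIs, and the real function body may order or name these steps differently.

```python
import asyncio
import time
from dataclasses import dataclass, field


@dataclass
class FakeProxy:
    """Hypothetical proxy stand-in; records which checkpoint it was restarted with."""
    served: list[int] = field(default_factory=list)

    async def restart(self, checkpoint: int) -> None:
        self.served.append(checkpoint)


@dataclass
class FakeTrainer:
    """Hypothetical training-client stand-in; counts optimizer steps."""
    step: int = 0

    async def train_on(self, rewards: list[float]) -> float:
        self.step += 1
        return sum(rewards) / len(rewards)  # mean reward as a toy metric


async def collect_rollouts(checkpoint: int, group_size: int = 4) -> list[float]:
    """Toy rollout: every reward equals the index of the checkpoint that produced it."""
    await asyncio.sleep(0)  # yield control, as a real rollout would
    return [float(checkpoint)] * group_size


async def sync_training_sketch(*, start_batch: int, end_batch: int, num_batches: int):
    proxy, trainer = FakeProxy(), FakeTrainer()
    history = []
    for i_batch in range(start_batch, end_batch):
        t_start = time.time()
        # 1. Restart the proxy so workers see the latest checkpoint.
        await proxy.restart(checkpoint=i_batch)
        # 2. Gather rollouts from that checkpoint only (on-policy).
        rewards = await collect_rollouts(checkpoint=i_batch)
        # 3. Finish the training step before the next batch begins (synchronous).
        mean_reward = await trainer.train_on(rewards)
        history.append({
            "progress/batch": i_batch,
            "progress/done_frac": (i_batch + 1) / num_batches,
            "reward/mean": mean_reward,          # hypothetical metric key
            "time/total": time.time() - t_start,  # hypothetical metric key
        })
    return proxy, trainer, history


proxy, trainer, history = asyncio.run(
    sync_training_sketch(start_batch=2, end_batch=5, num_batches=5)
)
```

Because each iteration awaits the rollouts and the training step in sequence, no rollout in this sketch is ever produced by a stale checkpoint. The same `(i_batch + 1) / num_batches` formula as in the excerpt reaches 1.0 on the final batch even when the run resumes from `start_batch=2`.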

Callers 1

main_training_loop · Function · 0.85

Calls 13

TinkerLLM · Class · 0.85
time · Method · 0.80
update_model_list · Method · 0.80
as_model_list · Method · 0.80
restart · Method · 0.80
as_resource · Method · 0.80
get_batch · Method · 0.80
add_resources · Method · 0.45
update · Method · 0.45

Tested by

no test coverage detected