MCPcopy
hub / github.com/microsoft/agent-lightning / _train_step

Method _train_step

agentlightning/verl/trainer.py:239–436  ·  view source on GitHub ↗
(self, batch_dict: dict)

Source from the content-addressed store, hash-verified

237 return ref_worker.compute_ref_log_prob(batch)
238
239 def _train_step(self, batch_dict: dict) -> dict:
240 # Isolate in a separate method to automatically recycle the variables before validation.
241 batch: DataProto = DataProto.from_single_dict(batch_dict)
242 metrics = {}
243 timing_raw = {}
244
245 with _timer("step", timing_raw):
246
247 # When agent mode is enabled, we read the batch as it is.
248 gen_batch = batch
249
250 # generate a batch
251 with _timer("gen", timing_raw):
252 self.async_rollout_manager.wake_up()
253 self.agent_mode_daemon.set_up_data_and_server(
254 gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses
255 )
256 self.agent_mode_daemon.run_until_all_finished()
257 batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch(
258 max_prompt_length=(
259 self.config.agentlightning.trace_aggregator.trajectory_max_prompt_length
260 if self.config.agentlightning.trace_aggregator.level.startswith("trajectory")
261 else self.config.data.max_prompt_length
262 ),
263 max_response_length=(
264 self.config.agentlightning.trace_aggregator.trajectory_max_response_length
265 if self.config.agentlightning.trace_aggregator.level.startswith("trajectory")
266 else self.config.data.max_response_length
267 ),
268 device=gen_batch.batch["fake_ids"].device,
269 global_steps=self.global_steps,
270 )
271 metrics.update(agent_metrics)
272 self.agent_mode_daemon.clear_data_and_server()
273 self.async_rollout_manager.sleep()
274
275 if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
276 with _timer("gen_max", timing_raw):
277 gen_baseline_batch = deepcopy(gen_batch)
278 gen_baseline_batch.meta_info["do_sample"] = False
279 gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
280
281 batch = batch.union(gen_baseline_output)
282 reward_baseline_tensor = self.reward_fn(batch)
283 reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
284
285 batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
286
287 batch.batch["reward_baselines"] = reward_baseline_tensor
288
289 del gen_baseline_batch, gen_baseline_output
290
291 # uid is used for algorithm like GRPO, should be aligned to data id
292 batch.non_tensor_batch["uid"] = batch.non_tensor_batch["data_id_list"]
293
294 if "response_mask" not in batch.batch:
295 batch.batch["response_mask"] = compute_response_mask(batch)
296

Callers 1

fit · Method · 0.95

Calls 10

_timer · Function · 0.85
compute_data_metrics · Function · 0.85
get_train_data_batch · Method · 0.80
clear_data_and_server · Method · 0.80
update · Method · 0.45
pop · Method · 0.45
get · Method · 0.45

Tested by

no test coverage detected