(self, batch_dict: dict)
| 237 | return ref_worker.compute_ref_log_prob(batch) |
| 238 | |
| 239 | def _train_step(self, batch_dict: dict) -> dict: |
| 240 | # Isolate in a separate method to automatically recycle the variables before validation. |
| 241 | batch: DataProto = DataProto.from_single_dict(batch_dict) |
| 242 | metrics = {} |
| 243 | timing_raw = {} |
| 244 | |
| 245 | with _timer("step", timing_raw): |
| 246 | |
| 247 | # When agent mode is enabled, we read the batch as it is. |
| 248 | gen_batch = batch |
| 249 | |
| 250 | # generate a batch |
| 251 | with _timer("gen", timing_raw): |
| 252 | self.async_rollout_manager.wake_up() |
| 253 | self.agent_mode_daemon.set_up_data_and_server( |
| 254 | gen_batch.non_tensor_batch, self.async_rollout_manager.server_addresses |
| 255 | ) |
| 256 | self.agent_mode_daemon.run_until_all_finished() |
| 257 | batch, agent_metrics = self.agent_mode_daemon.get_train_data_batch( |
| 258 | max_prompt_length=( |
| 259 | self.config.agentlightning.trace_aggregator.trajectory_max_prompt_length |
| 260 | if self.config.agentlightning.trace_aggregator.level.startswith("trajectory") |
| 261 | else self.config.data.max_prompt_length |
| 262 | ), |
| 263 | max_response_length=( |
| 264 | self.config.agentlightning.trace_aggregator.trajectory_max_response_length |
| 265 | if self.config.agentlightning.trace_aggregator.level.startswith("trajectory") |
| 266 | else self.config.data.max_response_length |
| 267 | ), |
| 268 | device=gen_batch.batch["fake_ids"].device, |
| 269 | global_steps=self.global_steps, |
| 270 | ) |
| 271 | metrics.update(agent_metrics) |
| 272 | self.agent_mode_daemon.clear_data_and_server() |
| 273 | self.async_rollout_manager.sleep() |
| 274 | |
| 275 | if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: |
| 276 | with _timer("gen_max", timing_raw): |
| 277 | gen_baseline_batch = deepcopy(gen_batch) |
| 278 | gen_baseline_batch.meta_info["do_sample"] = False |
| 279 | gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) |
| 280 | |
| 281 | batch = batch.union(gen_baseline_output) |
| 282 | reward_baseline_tensor = self.reward_fn(batch) |
| 283 | reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) |
| 284 | |
| 285 | batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) |
| 286 | |
| 287 | batch.batch["reward_baselines"] = reward_baseline_tensor |
| 288 | |
| 289 | del gen_baseline_batch, gen_baseline_output |
| 290 | |
| 291 | # uid is used by algorithms such as GRPO, so it should be aligned with the data id |
| 292 | batch.non_tensor_batch["uid"] = batch.non_tensor_batch["data_id_list"] |
| 293 | |
| 294 | if "response_mask" not in batch.batch: |
| 295 | batch.batch["response_mask"] = compute_response_mask(batch) |
| 296 |
no test coverage detected