The training loop of PPO. The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow. The lightweight advantage computation is done on the driver process.
(self)
| 23 | |
| 24 | |
| 25 | def fit(self): |
| 26 | """ |
| 27 | The training loop of PPO. |
| 28 | The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow. |
| 29 | The light-weight advantage computation is done on the driver process. |
| 30 | """ |
| 31 | from verl.utils.tracking import Tracking |
| 32 | from omegaconf import OmegaConf |
| 33 | |
| 34 | logger = Tracking(project_name=self.config.trainer.project_name, |
| 35 | experiment_name=self.config.trainer.experiment_name, |
| 36 | default_backend=self.config.trainer.logger, |
| 37 | config=OmegaConf.to_container(self.config, resolve=True)) |
| 38 | |
| 39 | global_steps = 0 |
| 40 | |
| 41 | # perform validation before training |
| 42 | # currently, we only support validation using the reward_function. |
| 43 | if self.val_reward_fn is not None: |
| 44 | val_metrics = self._validate() |
| 45 | pprint(f'Initial validation metrics: {val_metrics}') |
| 46 | |
| 47 | for epoch in range(self.config.trainer.total_epochs): |
| 48 | for batch_dict in self.train_dataloader: |
| 49 | metrics = {} |
| 50 | |
| 51 | batch: DataProto = DataProto.from_single_dict(batch_dict) |
| 52 | # batch = batch.to('cuda') |
| 53 | |
| 54 | # pop those keys for generation |
| 55 | gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) |
| 56 | |
| 57 | # generate a batch |
| 58 | with Timer(name='gen', logger=None) as timer: |
| 59 | gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) |
| 60 | metrics['timing/gen'] = timer.last |
| 61 | |
| 62 | batch = batch.union(gen_batch_output) |
| 63 | |
| 64 | if self.use_reference_policy: |
| 65 | # compute reference log_prob |
| 66 | with Timer(name='ref', logger=None) as timer: |
| 67 | ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) |
| 68 | batch = batch.union(ref_log_prob) |
| 69 | metrics['timing/ref'] = timer.last |
| 70 | |
| 71 | # compute values |
| 72 | with Timer(name='values', logger=None) as timer: |
| 73 | values = self.critic_wg.compute_values(batch) |
| 74 | batch = batch.union(values) |
| 75 | metrics['timing/values'] = timer.last |
| 76 | |
| 77 | with Timer(name='adv', logger=None) as timer: |
| 78 | # compute scores. Support both model and function-based. |
| 79 | # We first compute the scores using reward model. Then, we call reward_fn to combine |
| 80 | # the results from reward model and rule-based results. |
| 81 | if self.use_rm: |
| 82 | # we first compute reward model score |
nothing calls this directly
no test coverage detected