hub / github.com/Jiayi-Pan/TinyZero / fit

Function fit

examples/split_placement/split_monkey_patch.py:25–161

The PPO training loop. The driver process only needs to call the worker groups' compute functions over RPC to construct the PPO dataflow. The lightweight advantage computation runs on the driver process itself.
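The division of labor described above can be sketched with plain Python stand-ins. This is a minimal illustration under stated assumptions, not the verl implementation: `StubWorkerGroup`, `compute_advantage_on_driver`, and `run_one_step` are hypothetical names, the "RPC" calls are ordinary method calls, the reward is a constant, and the advantage is simplified to reward minus a value baseline.

```python
from typing import Dict, List


class StubWorkerGroup:
    """Hypothetical stand-in for a remote worker group (real calls would go over RPC)."""

    def generate_sequences(self, prompts: List[str]) -> List[str]:
        # Pretend generation: append a fixed continuation to each prompt.
        return [p + " -> answer" for p in prompts]

    def compute_values(self, responses: List[str]) -> List[float]:
        # Pretend critic: the value estimate grows with the number of tokens.
        return [0.1 * len(r.split()) for r in responses]


def compute_advantage_on_driver(rewards: List[float], values: List[float]) -> List[float]:
    """Cheap driver-side step: reward minus value baseline (a simplification)."""
    return [r - v for r, v in zip(rewards, values)]


def run_one_step(workers: StubWorkerGroup, prompts: List[str]) -> Dict[str, List[float]]:
    responses = workers.generate_sequences(prompts)  # heavy work, delegated to workers
    values = workers.compute_values(responses)       # heavy work, delegated to workers
    rewards = [1.0 for _ in responses]               # assumed constant rule-based reward
    advantages = compute_advantage_on_driver(rewards, values)  # light work, done locally
    return {"values": values, "advantages": advantages}


result = run_one_step(StubWorkerGroup(), ["2 + 2", "3 * 3"])
print(result)
```

The point of the split is that only the expensive model calls cross the process boundary, while the arithmetic that combines their outputs stays on the driver.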

(self)

Source from the content-addressed store, hash-verified

def fit(self):
    """
    The training loop of PPO.
    The driver process only need to call the compute functions of the worker group through RPC to construct the PPO dataflow.
    The light-weight advantage computation is done on the driver process.
    """
    from verl.utils.tracking import Tracking
    from omegaconf import OmegaConf

    logger = Tracking(project_name=self.config.trainer.project_name,
                      experiment_name=self.config.trainer.experiment_name,
                      default_backend=self.config.trainer.logger,
                      config=OmegaConf.to_container(self.config, resolve=True))

    global_steps = 0

    # perform validation before training
    # currently, we only support validation using the reward_function.
    if self.val_reward_fn is not None:
        val_metrics = self._validate()
        pprint(f'Initial validation metrics: {val_metrics}')

    for epoch in range(self.config.trainer.total_epochs):
        for batch_dict in self.train_dataloader:
            metrics = {}

            batch: DataProto = DataProto.from_single_dict(batch_dict)
            # batch = batch.to('cuda')

            # pop those keys for generation
            gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

            # generate a batch
            with Timer(name='gen', logger=None) as timer:
                gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
            metrics['timing/gen'] = timer.last

            batch = batch.union(gen_batch_output)

            if self.use_reference_policy:
                # compute reference log_prob
                with Timer(name='ref', logger=None) as timer:
                    ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
                    batch = batch.union(ref_log_prob)
                metrics['timing/ref'] = timer.last

            # compute values
            with Timer(name='values', logger=None) as timer:
                values = self.critic_wg.compute_values(batch)
                batch = batch.union(values)
            metrics['timing/values'] = timer.last

            with Timer(name='adv', logger=None) as timer:
                # compute scores. Support both model and function-based.
                # We first compute the scores using reward model. Then, we call reward_fn to combine
                # the results from reward model and rule-based results.
                if self.use_rm:
                    # we first compute reward model score
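Each stage in the listing follows the same pattern: pop the keys a worker needs, time the call, record `timer.last` under a `timing/...` key, and merge the result back into the batch. The sketch below reproduces that pattern with plain dictionaries. It rests on assumptions: `StepTimer` is a hypothetical stand-in for the `Timer` used in the source, and `pop_keys` and `union` are hypothetical helpers that only approximate what the `DataProto` methods of the same flavor might do.

```python
import time
from typing import Any, Dict, List, Optional


class StepTimer:
    """Hypothetical context manager that stores the elapsed seconds in `.last`."""

    def __init__(self) -> None:
        self.last: Optional[float] = None
        self._start = 0.0

    def __enter__(self) -> "StepTimer":
        self._start = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.last = time.perf_counter() - self._start


def pop_keys(batch: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
    """Remove `keys` from `batch` and return them as a new dict."""
    return {k: batch.pop(k) for k in keys}


def union(left: Dict[str, Any], right: Dict[str, Any]) -> Dict[str, Any]:
    """Merge two dicts, refusing to silently overwrite a shared key."""
    overlap = left.keys() & right.keys()
    if overlap:
        raise KeyError(f"conflicting keys: {sorted(overlap)}")
    return {**left, **right}


batch = {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]], "data_source": ["demo"]}
metrics: Dict[str, float] = {}

# Pop only what the (pretend) generator needs, then time the call.
gen_batch = pop_keys(batch, ["input_ids", "attention_mask"])
with StepTimer() as timer:
    gen_output = {"responses": [ids + [0] for ids in gen_batch["input_ids"]]}
metrics["timing/gen"] = timer.last

# Merge the generated fields back into what is left of the batch.
batch = union(batch, gen_output)
print(batch, metrics)
```

Reading `timer.last` after the `with` block exits matters, because the elapsed time is only written when the context manager closes.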

Callers

nothing calls this directly

Calls 15

log · Method · 0.95
Tracking · Class · 0.90
apply_kl_penalty · Function · 0.90
compute_advantage · Function · 0.90
reduce_metrics · Function · 0.90
compute_data_metrics · Function · 0.90
_validate · Method · 0.80
from_single_dict · Method · 0.80
pop · Method · 0.80
union · Method · 0.80
generate_sequences · Method · 0.45
compute_ref_log_prob · Method · 0.45

Tested by

no test coverage detected