MCPcopy
hub / github.com/hkust-nlp/simpleRL-reason / update_critic

Method update_critic

verl/workers/fsdp_workers.py:788–822  ·  view source on GitHub ↗
(self, data: DataProto)

Source from the content-addressed store, hash-verified

786
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
    """Train the critic on ``data`` and return the resulting metrics.

    The batch is moved to CUDA. Any CPU-offloaded FSDP parameters/gradients
    and optimizer state are loaded onto the current GPU. Then
    ``self.critic.update_critic`` runs inside the Ulysses sharding manager.
    After the update, the critic LR scheduler is stepped exactly once. When
    offloading is enabled, the FSDP state is offloaded again before returning.

    Args:
        data: Training batch. ``data.meta_info['global_token_num']`` must be
            present because it feeds the FLOPs estimate; a missing key raises
            ``KeyError``.

    Returns:
        A CPU ``DataProto`` built as ``DataProto(batch=None, meta_info=
        {'metrics': metrics})`` and passed through the sharding manager's
        ``postprocess_data``. ``metrics`` is the dict returned by
        ``self.critic.update_critic`` plus the keys ``'mfu/critic'`` and
        ``'critic/lr'``.
    """
    batch = data.to('cuda')

    # Reload offloaded FSDP state onto this GPU before training touches it.
    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.critic_module,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)
    if self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.critic_optimizer,
                            device_id=torch.cuda.current_device())

    with self.ulysses_sharding_manager:
        batch = self.ulysses_sharding_manager.preprocess_data(data=batch)

        # Time only the critic optimisation itself; the elapsed seconds feed
        # the FLOPs-utilisation estimate below.
        with Timer(name='update_critic', logger=None) as timer:
            stats = self.critic.update_critic(data=batch)
        elapsed = timer.last

        token_count = batch.meta_info['global_token_num']
        achieved_flops, peak_flops = self.flops_counter.estimate_flops(token_count, elapsed)
        # Ratio of estimated to promised FLOPs. It is presumably scaled by
        # ppo_epochs because the critic passes over the batch that many
        # times, and divided by world_size to get a per-rank figure.
        # TODO(review): confirm both against estimate_flops.
        stats['mfu/critic'] = achieved_flops * self.config.ppo_epochs / peak_flops / self.world_size

        # The LR schedule advances once per update_critic call.
        self.critic_lr_scheduler.step()
        stats['critic/lr'] = self.critic_lr_scheduler.get_last_lr()[0]

        result = self.ulysses_sharding_manager.postprocess_data(
            data=DataProto(batch=None, meta_info={'metrics': stats}))

    # Offload FSDP state back off the GPU and release cached GPU memory.
    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.critic_optimizer)
    torch.cuda.empty_cache()
    return result.to('cpu')
823
824 @register(dispatch_mode=Dispatch.ONE_TO_ALL)
825 def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=True):

Callers 1

fit — Method · 0.45

Calls 10

to — Method · 0.95
load_fsdp_param_and_grad — Function · 0.90
load_fsdp_optimizer — Function · 0.90
DataProto — Class · 0.90
offload_fsdp_optimizer — Function · 0.90
estimate_flops — Method · 0.80
step — Method · 0.80
preprocess_data — Method · 0.45
postprocess_data — Method · 0.45

Tested by

no test coverage detected