(self, data: DataProto)
| 786 | |
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_critic(self, data: DataProto):
    """Run one critic update on ``data`` and return the collected metrics.

    The batch is moved to CUDA, any offloaded FSDP parameters/gradients and
    optimizer state are loaded onto the current CUDA device, and the update is
    executed inside the Ulysses sharding manager context. Afterwards the
    state is offloaded again (when the corresponding flags are set), the CUDA
    cache is emptied, and the result is moved to the CPU.

    Args:
        data: Batch to train the critic on. Its ``meta_info`` must contain
            ``'global_token_num'``.

    Returns:
        A ``DataProto`` with ``batch=None`` whose ``meta_info['metrics']``
        holds the metrics returned by ``self.critic.update_critic`` plus
        ``'mfu/critic'`` and ``'critic/lr'``.
    """
    data = data.to('cuda')

    # Bring offloaded state back onto the active device before training.
    if self._is_offload_param:
        load_fsdp_param_and_grad(
            module=self.critic_module,
            device_id=torch.cuda.current_device(),
            load_grad=self._is_offload_grad,
        )
    if self._is_offload_optimizer:
        load_fsdp_optimizer(
            optimizer=self.critic_optimizer,
            device_id=torch.cuda.current_device(),
        )

    # NOTE(review): the extent of this `with` block was reconstructed from a
    # listing without indentation -- confirm it ends after postprocess_data.
    sharding_manager = self.ulysses_sharding_manager
    with sharding_manager:
        data = sharding_manager.preprocess_data(data=data)

        # Time only the critic update itself; the elapsed time feeds the MFU estimate.
        with Timer(name='update_critic', logger=None) as stopwatch:
            metrics = self.critic.update_critic(data=data)
        elapsed = stopwatch.last

        token_counts = data.meta_info['global_token_num']
        estimated_flops, promised_flops = self.flops_counter.estimate_flops(token_counts, elapsed)
        metrics['mfu/critic'] = estimated_flops * self.config.ppo_epochs / promised_flops / self.world_size

        # Advance the schedule, then report the learning rate of the first param group.
        scheduler = self.critic_lr_scheduler
        scheduler.step()
        metrics['critic/lr'] = scheduler.get_last_lr()[0]

        # Only metrics are returned; there is no tensor batch in the output.
        result = DataProto(batch=None, meta_info={'metrics': metrics})
        result = sharding_manager.postprocess_data(data=result)

    # Release device memory again when offloading is enabled.
    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.critic_optimizer)
    torch.cuda.empty_cache()

    return result.to('cpu')
| 823 | |
| 824 | @register(dispatch_mode=Dispatch.ONE_TO_ALL) |
| 825 | def save_checkpoint(self, local_path, hdfs_path=None, global_step=0, remove_previous_ckpt=True): |
no test coverage detected