(self, data: DataProto)
| 398 | |
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
    """Run one policy update on the actor and return the collected metrics.

    Args:
        data: Training batch. ``data.meta_info['global_token_num']`` is read
            to estimate FLOPs for the MFU metric.

    Returns:
        A ``DataProto`` moved to CPU whose ``meta_info['metrics']`` holds the
        metrics produced by ``self.actor.update_policy`` plus the
        ``'mfu/actor'`` and ``'actor/lr'`` entries added here.

    Raises:
        AssertionError: If this worker is not configured as an actor.
    """
    # NOTE(review): block nesting was reconstructed from a flattened listing;
    # confirm the extent of the sharding-manager context against the file.
    data = data.to('cuda')

    assert self._is_actor

    # Bring offloaded parameters / optimizer state back before training.
    if self._is_offload_param:
        load_fsdp_param_and_grad(
            module=self.actor_module_fsdp,
            device_id=torch.cuda.current_device(),
            load_grad=self._is_offload_grad,
        )
    if self._is_offload_optimizer:
        load_fsdp_optimizer(
            optimizer=self.actor_optimizer,
            device_id=torch.cuda.current_device(),
        )

    data.batch = data.batch.cuda()

    log_gpu_memory_usage('Before update policy', logger=logger)

    with self.ulysses_sharding_manager:
        data = self.ulysses_sharding_manager.preprocess_data(data=data)

        # perform training
        with Timer(name='update_policy', logger=None) as policy_timer:
            metrics = self.actor.update_policy(data=data)

        # The elapsed time is read only after the timer context has closed.
        elapsed = policy_timer.last
        token_count = data.meta_info['global_token_num']
        estimated_flops, promised_flops = self.flops_counter.estimate_flops(token_count, elapsed)
        scaled_flops = estimated_flops * self.config.actor.ppo_epochs
        metrics['mfu/actor'] = scaled_flops / promised_flops / self.world_size

        # Advance the schedule, then report the first learning rate it exposes.
        self.actor_lr_scheduler.step()
        metrics['actor/lr'] = self.actor_lr_scheduler.get_last_lr()[0]

        log_gpu_memory_usage('After update policy', logger=logger)

        # TODO: here, we should return all metrics
        metrics_proto = DataProto(meta_info={'metrics': metrics})
        gathered = self.ulysses_sharding_manager.postprocess_data(data=metrics_proto)
        output = gathered.to('cpu')

    # Push parameters / optimizer state back off the GPU when offloading is on.
    if self._is_offload_param:
        offload_fsdp_param_and_grad(
            module=self.actor_module_fsdp,
            offload_grad=self._is_offload_grad,
        )
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)

    torch.cuda.empty_cache()
    return output
| 443 | |
| 444 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) |
| 445 | def generate_sequences(self, prompts: DataProto): |
no test coverage detected