(self, data: DataProto)
| 354 | |
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
    """Run one policy update on the actor and return its metrics.

    The incoming batch is moved to CUDA. If the offload flags are set, the
    FSDP parameters/gradients and the optimizer are loaded before the update
    and offloaded again afterwards. The update runs inside the Ulysses
    sharding manager context, is timed, and the LR scheduler is stepped once.

    Args:
        data: Training batch; ``data.meta_info['global_token_num']`` is read
            to estimate FLOPs.

    Returns:
        A ``DataProto`` (moved to CPU) whose ``meta_info['metrics']`` holds
        the metrics from ``self.actor.update_policy`` plus ``'mfu/actor'``
        and ``'actor/lr'``.
    """
    # NOTE(review): block nesting below was reconstructed from a listing
    # that had lost its indentation -- confirm the extent of the outer
    # ``with`` against the upstream file.
    data = data.to('cuda')

    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_param_and_grad(
            module=self.actor_module_fsdp,
            device_id=torch.cuda.current_device(),
            load_grad=self._is_offload_grad,
        )
    if self._is_offload_optimizer:
        load_fsdp_optimizer(
            optimizer=self.actor_optimizer,
            device_id=torch.cuda.current_device(),
        )

    data.batch = data.batch.cuda()

    log_gpu_memory_usage('Before update policy', logger=logger)

    sharding = self.ulysses_sharding_manager
    with sharding:
        data = sharding.preprocess_data(data=data)

        # Time only the training step itself.
        with Timer(name='update_policy', logger=None) as stopwatch:
            metrics = self.actor.update_policy(data=data)
        elapsed = stopwatch.last

        token_count = data.meta_info['global_token_num']
        estimated_flops, promised_flops = self.flops_counter.estimate_flops(token_count, elapsed)
        # Same left-to-right arithmetic as before: scale by PPO epochs,
        # then divide by promised FLOPs and by the world size.
        epoch_flops = estimated_flops * self.config.actor.ppo_epochs
        metrics['mfu/actor'] = epoch_flops / promised_flops / self.world_size

        self.actor_lr_scheduler.step()
        metrics['actor/lr'] = self.actor_lr_scheduler.get_last_lr()[0]

        log_gpu_memory_usage('After update policy', logger=logger)

        # TODO: here, we should return all metrics
        output = sharding.postprocess_data(
            data=DataProto(meta_info={'metrics': metrics}),
        ).to('cpu')

    if self._is_offload_param:
        offload_fsdp_param_and_grad(
            module=self.actor_module_fsdp,
            offload_grad=self._is_offload_grad,
        )
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)
    torch.cuda.empty_cache()
    return output
| 399 | |
| 400 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) |
| 401 | def generate_sequences(self, prompts: DataProto): |
no test coverage detected