(self, data: DataProto)
| 311 | |
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
    """Run one actor policy update on ``data`` and return the training metrics.

    The batch is moved to CUDA, any offloaded FSDP parameters/gradients and
    optimizer state are loaded onto the current CUDA device, the policy is
    updated through ``self.actor.update_policy``, and the learning-rate
    scheduler is stepped. Offloaded state is pushed back off the device
    before returning.

    Args:
        data: The training batch handed to ``self.actor.update_policy``.

    Returns:
        A ``DataProto`` moved to CPU whose ``meta_info['metrics']`` holds the
        metrics returned by the policy update plus ``'actor/lr(1e-4)'`` (the
        first learning rate reported by the scheduler, multiplied by 1e4).

    Raises:
        AssertionError: If this worker is not an actor (``self._is_actor``
            is falsy).
    """
    data = data.to('cuda')

    assert self._is_actor

    # Bring offloaded FSDP state back onto the current CUDA device before
    # the update touches the parameters or the optimizer.
    if self._is_offload_param:
        load_fsdp_param_and_grad(
            module=self.actor_module_fsdp,
            device_id=torch.cuda.current_device(),
            load_grad=self._is_offload_grad,
        )
    if self._is_offload_optimizer:
        load_fsdp_optimizer(
            optimizer=self.actor_optimizer,
            device_id=torch.cuda.current_device(),
        )

    # NOTE(review): `data` was already passed through `.to('cuda')` above, so
    # this second transfer of the batch looks redundant -- confirm against
    # DataProto.to before removing it.
    data.batch = data.batch.cuda()

    log_gpu_memory_usage('Before update policy', logger=logger)
    policy_metrics = self.actor.update_policy(data=data)

    # Step the schedule after the update and report the new learning rate,
    # scaled by 1e4 as the metric key advertises.
    self.actor_lr_scheduler.step()
    policy_metrics['actor/lr(1e-4)'] = self.actor_lr_scheduler.get_last_lr()[0] * 1e4
    log_gpu_memory_usage('After update policy', logger=logger)

    # TODO: here, we should return all metrics
    result = DataProto(meta_info={'metrics': policy_metrics}).to('cpu')

    # Push FSDP state back off the device now that the update is finished.
    if self._is_offload_param:
        offload_fsdp_param_and_grad(
            module=self.actor_module_fsdp,
            offload_grad=self._is_offload_grad,
        )
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)

    # Wait for outstanding CUDA work, line up with the other ranks, then
    # release cached allocator blocks.
    torch.cuda.synchronize()
    torch.distributed.barrier()
    torch.cuda.empty_cache()
    return result
| 348 | |
| 349 | @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) |
| 350 | def generate_sequences(self, prompts: DataProto): |
no test coverage detected