MCPcopy
hub / github.com/hkust-nlp/simpleRL-reason / update_actor

Method update_actor

verl/workers/fsdp_workers.py:400–442  ·  view source on GitHub ↗
(self, data: DataProto)

Source from the content-addressed store, hash-verified

398
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
    """Run one policy-update pass on the actor and return its metrics.

    Moves ``data`` to CUDA, calls the FSDP load helpers when the matching
    offload flags are set, runs ``self.actor.update_policy`` under the
    Ulysses sharding manager, steps the LR scheduler, and finally calls the
    FSDP offload helpers again before returning.

    Args:
        data: Batch for the update. ``data.meta_info['global_token_num']``
            is read to estimate FLOPs.

    Returns:
        A ``DataProto`` moved to CPU whose ``meta_info['metrics']`` holds the
        metrics returned by ``update_policy`` plus the ``'mfu/actor'`` and
        ``'actor/lr'`` entries added here.

    Raises:
        AssertionError: If ``self._is_actor`` is falsy.
    """
    data = data.to('cuda')

    assert self._is_actor
    if self._is_offload_param:
        load_fsdp_param_and_grad(
            module=self.actor_module_fsdp,
            device_id=torch.cuda.current_device(),
            load_grad=self._is_offload_grad,
        )
    if self._is_offload_optimizer:
        load_fsdp_optimizer(
            optimizer=self.actor_optimizer,
            device_id=torch.cuda.current_device(),
        )

    data.batch = data.batch.cuda()

    log_gpu_memory_usage('Before update policy', logger=logger)

    # NOTE(review): pre-/post-processing and the CPU transfer are all kept
    # inside the sharding-manager context -- confirm against upstream layout.
    sharding_manager = self.ulysses_sharding_manager
    with sharding_manager:
        data = sharding_manager.preprocess_data(data=data)

        # perform training; the timer wraps only the update_policy call
        with Timer(name='update_policy', logger=None) as policy_timer:
            metrics = self.actor.update_policy(data=data)
        elapsed = policy_timer.last

        token_count = data.meta_info['global_token_num']
        estimated_flops, promised_flops = self.flops_counter.estimate_flops(token_count, elapsed)
        # Scale by the configured PPO epoch count, then normalise by the
        # promised FLOPs and the world size.
        total_flops = estimated_flops * self.config.actor.ppo_epochs
        metrics['mfu/actor'] = total_flops / promised_flops / self.world_size

        scheduler = self.actor_lr_scheduler
        scheduler.step()
        metrics['actor/lr'] = scheduler.get_last_lr()[0]

        log_gpu_memory_usage('After update policy', logger=logger)

        # TODO: here, we should return all metrics
        result = DataProto(meta_info={'metrics': metrics})
        result = sharding_manager.postprocess_data(data=result).to('cpu')

    if self._is_offload_param:
        offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
    if self._is_offload_optimizer:
        offload_fsdp_optimizer(optimizer=self.actor_optimizer)
    torch.cuda.empty_cache()
    return result
443
444 @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
445 def generate_sequences(self, prompts: DataProto):

Callers 1

fit (Method) · 0.45

Calls 12

to (Method) · 0.95
load_fsdp_param_and_grad (Function) · 0.90
load_fsdp_optimizer (Function) · 0.90
log_gpu_memory_usage (Function) · 0.90
DataProto (Class) · 0.90
offload_fsdp_optimizer (Function) · 0.90
estimate_flops (Method) · 0.80
step (Method) · 0.80
preprocess_data (Method) · 0.45
update_policy (Method) · 0.45
postprocess_data (Method) · 0.45

Tested by

no test coverage detected