MCPcopy
hub / github.com/Jiayi-Pan/TinyZero / update_actor

Method update_actor

verl/workers/fsdp_workers.py:356–398  ·  view source on GitHub ↗
(self, data: DataProto)

Source from the content-addressed store, hash-verified

354
@register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
def update_actor(self, data: DataProto):
    """Run one actor policy update on ``data`` and return the resulting metrics.

    When the offload flags are set, the FSDP parameters/gradients and the
    optimizer state are loaded onto the current CUDA device before the update
    and offloaded again afterwards. The offload (and the CUDA cache release)
    runs in a ``finally`` clause, so it also happens when the update raises.

    Args:
        data: Batch to train on. ``data.meta_info['global_token_num']`` is
            read after preprocessing to estimate FLOPs.

    Returns:
        A ``DataProto`` moved to CPU whose ``meta_info['metrics']`` is the
        dict returned by ``self.actor.update_policy`` with ``'mfu/actor'``
        and ``'actor/lr'`` added.

    Raises:
        AssertionError: If this worker is not configured as an actor.
    """
    # Explicit raise (same exception type as the former `assert`) so the check
    # is not stripped under `python -O`, and done before any device transfer.
    if not self._is_actor:
        raise AssertionError('update_actor called on a worker that is not an actor')

    data = data.to('cuda')

    if self._is_offload_param:
        load_fsdp_param_and_grad(module=self.actor_module_fsdp,
                                 device_id=torch.cuda.current_device(),
                                 load_grad=self._is_offload_grad)
    if self._is_offload_optimizer:
        load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=torch.cuda.current_device())

    try:
        # NOTE(review): possibly redundant after `data.to('cuda')` above;
        # kept because DataProto.to is not visible here -- confirm.
        data.batch = data.batch.cuda()

        log_gpu_memory_usage('Before update policy', logger=logger)

        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)
            # perform training
            with Timer(name='update_policy', logger=None) as timer:
                metrics = self.actor.update_policy(data=data)
            # presumably the elapsed time of the timed block -- confirm against Timer
            delta_time = timer.last
            global_num_tokens = data.meta_info['global_token_num']
            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
            metrics['mfu/actor'] = estimated_flops * self.config.actor.ppo_epochs / promised_flops / self.world_size

            # The scheduler is stepped once per call; the first entry of
            # get_last_lr() is reported.
            self.actor_lr_scheduler.step()
            lr = self.actor_lr_scheduler.get_last_lr()[0]
            metrics['actor/lr'] = lr

            log_gpu_memory_usage('After update policy', logger=logger)

            # TODO: here, we should return all metrics
            output = DataProto(meta_info={'metrics': metrics})

            output = self.ulysses_sharding_manager.postprocess_data(data=output)
            output = output.to('cpu')
    finally:
        # Always hand the GPU memory back, even if the update failed part-way,
        # so loaded actor state does not stay resident on the device.
        if self._is_offload_param:
            offload_fsdp_param_and_grad(module=self.actor_module_fsdp, offload_grad=self._is_offload_grad)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)
        torch.cuda.empty_cache()
    return output
399
400 @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
401 def generate_sequences(self, prompts: DataProto):

Callers 2

fit · Method · 0.45
fit · Function · 0.45

Calls 12

to · Method · 0.95
load_fsdp_param_and_grad · Function · 0.90
load_fsdp_optimizer · Function · 0.90
log_gpu_memory_usage · Function · 0.90
DataProto · Class · 0.90
offload_fsdp_optimizer · Function · 0.90
estimate_flops · Method · 0.80
step · Method · 0.80
preprocess_data · Method · 0.45
update_policy · Method · 0.45
postprocess_data · Method · 0.45

Tested by

no test coverage detected