MCPcopy
hub / github.com/PRIME-RL/PRIME / DataParallelPPOActor

Class DataParallelPPOActor

training/verl/workers/actor/dp_actor.py:33–209  ·  view source on GitHub ↗

Source from the content-addressed store, hash-verified

# Names exported by ``from <module> import *``.
__all__ = ["DataParallelPPOActor"]
32
33class DataParallelPPOActor(BasePPOActor):
34
35 def __init__(
36 self,
37 config,
38 actor_module: nn.Module,
39 actor_optimizer: torch.optim.Optimizer = None,
40 ):
41 """When optimizer is None, it is Reference Policy"""
42 super().__init__(config)
43 self.actor_module = actor_module
44 self.actor_optimizer = actor_optimizer
45 self.use_remove_padding = self.config.get('use_remove_padding', False)
46 print(f'Actor use_remove_padding={self.use_remove_padding}')
47
48
49 def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
50 response_length = micro_batch['responses'].size(-1)
51 with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
52 input_ids = micro_batch['input_ids']
53 batch_size, seqlen = input_ids.shape
54 attention_mask = micro_batch['attention_mask']
55 position_ids = micro_batch['position_ids']
56 if self.use_remove_padding:
57 input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
58 input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
59 input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
60
61 # unpad the position_ids to align the rotary
62 position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
63 indices).transpose(0, 1)
64 # only pass input_ids and position_ids to enable flash_attn_varlen
65 output = self.actor_module(input_ids=input_ids_rmpad,
66 attention_mask=None,
67 position_ids=position_ids_rmpad,
68 use_cache=False) # prevent model thinks we are generating
69 logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)
70 logits_rmpad /= temperature
71 log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
72 logits_rmpad=logits_rmpad,
73 indices=indices,
74 batch_size=batch_size,
75 seqlen=seqlen,
76 response_length=response_length) # (batch, seqlen)
77 logits = logits_rmpad
78 else:
79 output = self.actor_module(input_ids=input_ids,
80 attention_mask=attention_mask,
81 position_ids=position_ids,
82 use_cache=False) # prevent model thinks we are generating
83 logits = output.logits / temperature
84 logits = logits[:, -response_length - 1:-1]
85 log_probs = logprobs_from_logits(logits, micro_batch['responses'])
86
87 return logits, log_probs
88
89 def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
90 """Make minibatch iterator for updating the actor

Callers 1

init_model (method) · 0.90

Calls

no outgoing calls

Tested by

no test coverage detected