# Names exported by `from <this module> import *`: only the actor class.
__all__ = ['DataParallelPPOActor']
| 32 | |
| 33 | class DataParallelPPOActor(BasePPOActor): |
| 34 | |
def __init__(
    self,
    config,
    actor_module: nn.Module,
    actor_optimizer: torch.optim.Optimizer = None,
):
    """Store the wrapped module/optimizer and read the remove-padding switch.

    Args:
        config: Forwarded unchanged to the base-class constructor. It is
            later read through ``self.config.get(...)``, so it is presumably
            a dict-like object that the base class stores on ``self.config``
            (TODO confirm against ``BasePPOActor``).
        actor_module: The model that ``_forward_micro_batch`` calls.
        actor_optimizer: Optimizer for ``actor_module``. When it is ``None``
            this instance acts as the reference policy.
    """
    super().__init__(config)
    self.actor_module, self.actor_optimizer = actor_module, actor_optimizer

    # A missing key means remove-padding is off.
    remove_padding = self.config.get('use_remove_padding', False)
    self.use_remove_padding = remove_padding
    print(f'Actor use_remove_padding={self.use_remove_padding}')
| 47 | |
| 48 | |
def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the actor on one micro-batch and return scaled logits and log-probs.

    Args:
        micro_batch: Mapping read for the keys ``'input_ids'``,
            ``'attention_mask'``, ``'position_ids'`` and ``'responses'``.
            ``input_ids`` is unpacked as a 2-D ``(batch, seqlen)`` tensor.
        temperature: Divisor applied to the raw model logits.

    Returns:
        ``(logits, log_probs)``.

        * Remove-padding disabled: ``logits`` is the temperature-scaled
          output sliced to ``[:, -response_length - 1:-1]`` and
          ``log_probs`` is ``logprobs_from_logits`` of that slice against
          ``micro_batch['responses']``.
        * Remove-padding enabled: ``logits`` is the un-padded,
          temperature-scaled output (one row per non-padded token) and is
          not sliced to the response window; ``log_probs`` comes from
          ``log_probs_from_logits_all_rmpad``.

    NOTE(review): the two branches return differently shaped ``logits``.
    Confirm that callers handle both layouts.
    """
    responses = micro_batch['responses']
    resp_len = responses.size(-1)

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        token_ids = micro_batch['input_ids']
        n_seqs, padded_len = token_ids.shape
        pad_mask = micro_batch['attention_mask']
        pos_ids = micro_batch['position_ids']

        if not self.use_remove_padding:
            # Padded path: feed the batch as-is, then keep only the
            # positions that predict the response tokens.
            model_out = self.actor_module(
                input_ids=token_ids,
                attention_mask=pad_mask,
                position_ids=pos_ids,
                use_cache=False,  # keep the model out of generation mode
            )
            scaled = model_out.logits / temperature
            logits = scaled[:, -resp_len - 1:-1]
            return logits, logprobs_from_logits(logits, responses)

        # Un-padded path: strip padded positions before the forward pass.
        flat_ids, indices, _, _ = unpad_input(token_ids.unsqueeze(-1), pad_mask)  # (total_nnz, ...)
        flat_ids = flat_ids.transpose(0, 1)  # (1, total_nnz)

        # Un-pad the position ids the same way so rotary embeddings line up.
        flat_positions = rearrange(pos_ids.unsqueeze(-1), "b s ... -> (b s) ...")
        flat_pos_ids = index_first_axis(flat_positions, indices).transpose(0, 1)

        # Only input_ids and position_ids are passed (no attention mask) to
        # enable flash_attn_varlen.
        model_out = self.actor_module(
            input_ids=flat_ids,
            attention_mask=None,
            position_ids=flat_pos_ids,
            use_cache=False,  # keep the model out of generation mode
        )
        logits_rmpad = model_out.logits.squeeze(0)  # (total_nnz, vocab_size)
        logits_rmpad /= temperature  # in-place, as in the original

        # NOTE(review): the original comment gives this shape as
        # (batch, seqlen); the helper also takes response_length, so verify
        # the actual output shape against the helper.
        log_probs = log_probs_from_logits_all_rmpad(
            input_ids_rmpad=flat_ids,
            logits_rmpad=logits_rmpad,
            indices=indices,
            batch_size=n_seqs,
            seqlen=padded_len,
            response_length=resp_len,
        )
        return logits_rmpad, log_probs
| 88 | |
| 89 | def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]: |
| 90 | """Make minibatch iterator for updating the actor |