Numerically stable evaluation of log(1 - exp(x)) for x < 0.
Signature: `_log1mexp(x: torch.Tensor) -> torch.Tensor`
| 41 | |
| 42 | |
| 43 | def _log1mexp(x: torch.Tensor) -> torch.Tensor: |
| 44 | """Numerically stable log(1 - exp(x)) for x < 0.""" |
| 45 | return torch.where(x > -0.6931, torch.log(-torch.expm1(x)), torch.log1p(-torch.exp(x))) |
| 46 | |
| 47 | |
| 48 | def orpo_loss( |