```python
import torch
import torch.nn.functional as F


def compute_confidence_aware_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    *,
    lambda_conf: float = 0.0,
    temperature: float = 1.0,
    per_token_weights: torch.Tensor | None = None,
    ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes a confidence-aware training loss for token classification-style heads.

    This loss combines:
    - `loss_sft`: standard supervised cross-entropy on all non-ignored labels.
    - `loss_conf`: an entropy penalty applied only to tokens that are already predicted correctly.

    Args:
        logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`.
        labels (`torch.Tensor`): Labels of shape `(...)`, matching `logits.shape[:-1]`. Values set to `ignore_index`
            are excluded from both losses.
        lambda_conf (`float`, *optional*, defaults to `0.0`): Weight for the confidence term.
        temperature (`float`, *optional*, defaults to `1.0`): Temperature used for the entropy term only. Lower values
            sharpen the distribution and change the strength of the confidence gradients.
        per_token_weights (`torch.Tensor`, *optional*): Optional weights of shape `(...)` to reweight both losses per
            token (e.g. schedule-aware weights). Tokens with weight `0` contribute nothing.
        ignore_index (`int`, *optional*, defaults to `-100`): Ignore index for labels.

    Returns:
        `tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`.
    """
    if logits.ndim < 2:
        raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.")
    if labels.shape != logits.shape[:-1]:
        raise ValueError(
            f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}."
        )
    if temperature <= 0:
        raise ValueError(f"`temperature` must be > 0, got {temperature}.")

    valid = labels.ne(ignore_index)
    if per_token_weights is None:
        weights = torch.ones_like(labels, dtype=logits.dtype)
    else:
        if per_token_weights.shape != labels.shape:
            raise ValueError(
                f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}."
            )
        weights = per_token_weights.to(dtype=logits.dtype)

    # Supervised CE (optionally weighted).
    vocab_size = logits.shape[-1]
    per_token_nll = F.cross_entropy(
        logits.reshape(-1, vocab_size),
        labels.reshape(-1),
        reduction="none",
        ignore_index=ignore_index,
    ).reshape_as(labels)
    # ... (remainder of the function not included in this excerpt)
```
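The excerpt stops right after `per_token_nll` is computed, so the masking, normalization, entropy term and return statement of `compute_confidence_aware_loss` are not visible. The following is a hypothetical completion built only from the docstring, not the author's implementation. The name `confidence_aware_loss_sketch` is illustrative, and three choices are assumptions: (a) both terms are divided by the total weight of valid tokens, (b) the entropy uses the natural log of `softmax(logits / temperature)`, and (c) the terms are combined as `loss = loss_sft + lambda_conf * loss_conf`. The shape checks from the original are omitted for brevity.

```python
from __future__ import annotations

import torch
import torch.nn.functional as F


def confidence_aware_loss_sketch(
    logits: torch.Tensor,
    labels: torch.Tensor,
    *,
    lambda_conf: float = 0.0,
    temperature: float = 1.0,
    per_token_weights: torch.Tensor | None = None,
    ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Hypothetical completion of the loss; shape checks from the original are omitted."""
    valid = labels.ne(ignore_index)
    if per_token_weights is None:
        weights = torch.ones_like(labels, dtype=logits.dtype)
    else:
        weights = per_token_weights.to(device=logits.device, dtype=logits.dtype)

    # Supervised CE per token; positions equal to ignore_index get 0 here.
    vocab_size = logits.shape[-1]
    per_token_nll = F.cross_entropy(
        logits.reshape(-1, vocab_size),
        labels.reshape(-1),
        reduction="none",
        ignore_index=ignore_index,
    ).reshape_as(labels)

    # ASSUMPTION: both terms are divided by the total weight of valid tokens.
    sft_weights = weights * valid.to(logits.dtype)
    denom = sft_weights.sum().clamp_min(1e-8)  # guard against an all-ignored batch
    loss_sft = (per_token_nll * sft_weights).sum() / denom

    # ASSUMPTION: per-token entropy of softmax(logits / temperature), natural log.
    log_probs = F.log_softmax(logits / temperature, dim=-1)
    entropy = -(log_probs.exp() * log_probs).sum(dim=-1)

    # Penalize entropy only where the argmax prediction already equals the label
    # (dividing by a positive temperature does not change the argmax).
    correct = valid & logits.argmax(dim=-1).eq(labels)
    loss_conf = (entropy * weights * correct.to(logits.dtype)).sum() / denom

    # ASSUMPTION: the confidence term is added with weight lambda_conf.
    loss = loss_sft + lambda_conf * loss_conf
    return loss, loss_sft, loss_conf


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(2, 4, 10, requires_grad=True)
    labels = torch.randint(0, 10, (2, 4))
    labels[0, 0] = -100  # this position is excluded from both terms
    loss, loss_sft, loss_conf = confidence_aware_loss_sketch(logits, labels, lambda_conf=0.1)
    loss.backward()  # gradients flow through both terms
    print(loss.item(), loss_sft.item(), loss_conf.item())
```

With no weights, `loss_sft` in this sketch reduces to PyTorch's default mean cross-entropy over non-ignored tokens, and `loss_conf` is exactly zero whenever no token is predicted correctly. Because `temperature` only enters through `log_softmax`, it changes the entropy values and their gradients but not which tokens count as correct.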