| 57 | |
| 58 | |
| 59 | class InvalidScoreLogitsProcessor(LogitsProcessor): |
| 60 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: |
| 61 | if torch.isnan(scores).any() or torch.isinf(scores).any(): |
| 62 | scores.zero_() |
| 63 | scores[..., 5] = 5e4 |
| 64 | return scores |
| 65 | |
| 66 | def split_tensor_along_last_dim( |
| 67 | tensor: torch.Tensor, |