:class:`transformers.LogitsProcessor` that stops generation as soon as any of the specified sequences appears. Args: stop_words_ids (:obj:`List[List[int]]`): List of lists of token ids of the stop words. In order to get the tokens of the words that should not appear in t
| 314 | |
| 315 | |
| 316 | class StopWordsLogitsProcessor(LogitsProcessor): |
| 317 | """ |
| 318 | :class:`transformers.LogitsProcessor` that stops generation as soon as any of the specified sequences appears. |
| 319 | |
| 320 | Args: |
| 321 | stop_words_ids (:obj:`List[List[int]]`): |
| 322 | List of lists of token ids, one list per stop word. To get the token ids of a |
| 323 | stop word, use :obj:`tokenizer(stop_word, |
| 324 | add_prefix_space=True).input_ids`. |
| 325 | eos_token_id (:obj:`int`): |
| 326 | The id of the `end-of-sequence` token. |
| 327 | """ |
| 328 | |
| 329 | def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int): |
| 330 | |
| 331 | if not isinstance(stop_words_ids, List) or len(stop_words_ids) == 0: |
| 332 | raise ValueError( |
| 333 | f"`stop_words_ids` has to be a non-emtpy list, but is {stop_words_ids}." |
| 334 | ) |
| 335 | if any(not isinstance(bad_word_ids, list) for bad_word_ids in stop_words_ids): |
| 336 | raise ValueError( |
| 337 | f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}." |
| 338 | ) |
| 339 | if any( |
| 340 | any( |
| 341 | (not isinstance(token_id, (int, np.integer)) or token_id < 0) |
| 342 | for token_id in stop_word_ids |
| 343 | ) |
| 344 | for stop_word_ids in stop_words_ids |
| 345 | ): |
| 346 | raise ValueError( |
| 347 | f"Each list in `stop_words_ids` has to be a list of positive integers, but is {stop_words_ids}." |
| 348 | ) |
| 349 | |
| 350 | self.stop_words_ids = list( |
| 351 | filter( |
| 352 | lambda bad_token_seq: bad_token_seq != [eos_token_id], stop_words_ids |
| 353 | ) |
| 354 | ) |
| 355 | self.eos_token_id = eos_token_id |
| 356 | for stop_token_seq in self.stop_words_ids: |
| 357 | assert ( |
| 358 | len(stop_token_seq) > 0 |
| 359 | ), "Stop words token sequences {} cannot have an empty list".format( |
| 360 | stop_words_ids |
| 361 | ) |
| 362 | |
| 363 | def __call__( |
| 364 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor |
| 365 | ) -> torch.FloatTensor: |
| 366 | stopped_samples = self._calc_stopped_samples(input_ids) |
| 367 | for i, should_stop in enumerate(stopped_samples): |
| 368 | if should_stop: |
| 369 | scores[i, self.eos_token_id] = float(2**15) |
| 370 | return scores |
| 371 | |
| 372 | def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool: |
| 373 | if len(tokens) == 0: |
no outgoing calls
no test coverage detected