MCPcopy
hub / github.com/QwenLM/Qwen-Audio / StopWordsLogitsProcessor

Class StopWordsLogitsProcessor

qwen_generation_utils.py:316–396  ·  view source on GitHub ↗

:class:`transformers.LogitsProcessor` that enforces that generation stops when any of the specified sequences appears. Args: stop_words_ids (:obj:`List[List[int]]`): List of list of token ids of stop ids. In order to get the tokens of the words that should not appear in t

Source from the content-addressed store, hash-verified

314
315
316class StopWordsLogitsProcessor(LogitsProcessor):
317 """
318 :class:`transformers.LogitsProcessor` that enforces that generation stops when any of the specified sequences appears.
319
320 Args:
321 stop_words_ids (:obj:`List[List[int]]`):
322 List of list of token ids of stop ids. In order to get the tokens of the words
323 that should not appear in the generated text, use :obj:`tokenizer(bad_word,
324 add_prefix_space=True).input_ids`.
325 eos_token_id (:obj:`int`):
326 The id of the `end-of-sequence` token.
327 """
328
def __init__(self, stop_words_ids: Iterable[Iterable[int]], eos_token_id: int):
    """Validate the stop sequences and store them with the EOS token id.

    Args:
        stop_words_ids: Non-empty ``list`` of non-empty ``list`` objects,
            each holding non-negative integer token ids (Python ``int`` or
            ``numpy`` integer). Despite the broader annotation, real
            ``list`` instances are required at both levels.
        eos_token_id: Id of the end-of-sequence token. A stop sequence
            equal to ``[eos_token_id]`` is dropped from the stored list.

    Raises:
        ValueError: If ``stop_words_ids`` is not a non-empty list of lists,
            if any token id is not a non-negative integer, or if any inner
            list is empty.
    """
    if not isinstance(stop_words_ids, list) or len(stop_words_ids) == 0:
        raise ValueError(
            f"`stop_words_ids` has to be a non-empty list, but is {stop_words_ids}."
        )
    if any(not isinstance(token_seq, list) for token_seq in stop_words_ids):
        raise ValueError(
            f"`stop_words_ids` has to be a list of lists, but is {stop_words_ids}."
        )
    if any(
        not isinstance(token_id, (int, np.integer)) or token_id < 0
        for token_seq in stop_words_ids
        for token_id in token_seq
    ):
        # Zero is accepted by the check, so the message says "non-negative".
        raise ValueError(
            f"Each list in `stop_words_ids` has to be a list of non-negative integers, but is {stop_words_ids}."
        )
    # Raised explicitly instead of asserted: `assert` statements are removed
    # under `python -O`, which would let empty sequences through silently.
    if any(len(token_seq) == 0 for token_seq in stop_words_ids):
        raise ValueError(
            "Stop words token sequences {} cannot have an empty list".format(
                stop_words_ids
            )
        )

    # The bare EOS sequence is filtered out of the stored stop sequences.
    self.stop_words_ids = [
        token_seq for token_seq in stop_words_ids if token_seq != [eos_token_id]
    ]
    self.eos_token_id = eos_token_id
362
def __call__(
    self, input_ids: torch.LongTensor, scores: torch.FloatTensor
) -> torch.FloatTensor:
    """Raise the EOS score for every sample flagged as stopped.

    The per-sample flags come from ``self._calc_stopped_samples(input_ids)``.
    For each truthy flag at position ``i``, ``scores[i, eos_token_id]`` is
    overwritten in place with ``float(2**15)``.

    Args:
        input_ids: Token ids generated so far, passed through unchanged to
            ``_calc_stopped_samples``.
        scores: Scores indexed as ``[sample, token_id]``; modified in place.

    Returns:
        The same ``scores`` object that was passed in.
    """
    # NOTE(review): presumably large enough to make EOS dominate the next
    # token selection — confirm against the sampling code.
    forced_eos_score = float(2**15)
    stop_flags = self._calc_stopped_samples(input_ids)
    for sample_idx, is_stopped in enumerate(stop_flags):
        if not is_stopped:
            continue
        scores[sample_idx, self.eos_token_id] = forced_eos_score
    return scores
371
372 def _tokens_match(self, prev_tokens: torch.LongTensor, tokens: List[int]) -> bool:
373 if len(tokens) == 0:

Callers 2

chat_streamMethod · 0.85
generateMethod · 0.85

Calls

no outgoing calls

Tested by

no test coverage detected