(
self,
inputs: Optional[torch.Tensor] = None,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[
Callable[[int, torch.Tensor], List[int]]
] = None,
synced_gpus: Optional[bool] = None,
assistant_model: Optional["PreTrainedModel"] = None,
streamer: Optional["BaseStreamer"] = None,
**kwargs,
)
| 1293 | return stream_generator() |
| 1294 | |
def generate(
    self,
    inputs: Optional[torch.Tensor] = None,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[
        Callable[[int, torch.Tensor], List[int]]
    ] = None,
    synced_gpus: Optional[bool] = None,
    assistant_model: Optional["PreTrainedModel"] = None,
    streamer: Optional["BaseStreamer"] = None,
    **kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
    """Generate tokens, injecting a stop-words logits processor when configured.

    Thin wrapper around the parent class's ``generate``.

    Before delegating, it resolves ``stop_words_ids`` in this order:

    1. an explicit ``stop_words_ids`` keyword argument;
    2. the ``stop_words_ids`` attribute of the effective generation config.

    If a value is found, it adds a ``StopWordsLogitsProcessor`` to the logits
    processors.

    Args:
        inputs: Forwarded positionally to the parent ``generate``.
        generation_config: Config to use. Falls back to
            ``self.generation_config`` when ``None``.
        logits_processor: Extra logits processors. This list is never
            modified in place; when a stop-words processor is added, a new
            ``LogitsProcessorList`` is built instead.
        stopping_criteria: Forwarded unchanged.
        prefix_allowed_tokens_fn: Forwarded unchanged.
        synced_gpus: Forwarded unchanged.
        assistant_model: Forwarded unchanged.
        streamer: Forwarded unchanged.
        **kwargs: Forwarded to the parent ``generate``. ``stop_words_ids`` is
            removed first (presumably a list of token-id sequences — confirm
            against ``StopWordsLogitsProcessor``).

    Returns:
        Whatever the parent ``generate`` returns (annotated as
        ``GenerateOutput`` or ``torch.LongTensor``).
    """
    generation_config = (
        generation_config if generation_config is not None else self.generation_config
    )

    # An explicit kwarg wins over the config value. It is popped (not just
    # read) so the parent ``generate`` never receives it.
    stop_words_ids = kwargs.pop("stop_words_ids", None)
    if stop_words_ids is None and generation_config is not None:
        stop_words_ids = getattr(generation_config, "stop_words_ids", None)

    if stop_words_ids is not None:
        stop_words_logits_processor = StopWordsLogitsProcessor(
            stop_words_ids=stop_words_ids,
            eos_token_id=generation_config.eos_token_id,
        )
        if logits_processor is None:
            logits_processor = LogitsProcessorList([stop_words_logits_processor])
        else:
            # Copy rather than ``append``: mutating the caller's list would
            # add another stop-words processor on every call that reuses it.
            logits_processor = LogitsProcessorList(
                [*logits_processor, stop_words_logits_processor]
            )

    return super().generate(
        inputs,
        generation_config=generation_config,
        logits_processor=logits_processor,
        stopping_criteria=stopping_criteria,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        synced_gpus=synced_gpus,
        assistant_model=assistant_model,
        streamer=streamer,
        **kwargs,
    )
| 1339 | |
| 1340 | |
| 1341 | class RotaryEmbedding(torch.nn.Module): |
no test coverage detected