(
model,
tokenizer,
prompt,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
additional_eos_token_id: Optional[int] = None,
**kwargs,
)
| 23 | |
| 24 | @torch.inference_mode() |
| 25 | def generate_interactive( |
| 26 | model, |
| 27 | tokenizer, |
| 28 | prompt, |
| 29 | generation_config: Optional[GenerationConfig] = None, |
| 30 | logits_processor: Optional[LogitsProcessorList] = None, |
| 31 | stopping_criteria: Optional[StoppingCriteriaList] = None, |
| 32 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, |
| 33 | additional_eos_token_id: Optional[int] = None, |
| 34 | **kwargs, |
| 35 | ): |
| 36 | inputs = tokenizer([prompt], padding=True, return_tensors="pt") |
| 37 | input_length = len(inputs["input_ids"][0]) |
| 38 | for k, v in inputs.items(): |
| 39 | inputs[k] = v.cuda() |
| 40 | input_ids = inputs["input_ids"] |
| 41 | batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] |
| 42 | if generation_config is None: |
| 43 | generation_config = model.generation_config |
| 44 | generation_config = copy.deepcopy(generation_config) |
| 45 | model_kwargs = generation_config.update(**kwargs) |
| 46 | bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id |
| 47 | if isinstance(eos_token_id, int): |
| 48 | eos_token_id = [eos_token_id] |
| 49 | if additional_eos_token_id is not None: |
| 50 | eos_token_id.append(additional_eos_token_id) |
| 51 | has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None |
| 52 | if has_default_max_length and generation_config.max_new_tokens is None: |
| 53 | warnings.warn( |
| 54 | f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. " |
| 55 | "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we" |
| 56 | " recommend using `max_new_tokens` to control the maximum length of the generation.", |
| 57 | UserWarning, |
| 58 | ) |
| 59 | elif generation_config.max_new_tokens is not None: |
| 60 | generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length |
| 61 | if not has_default_max_length: |
| 62 | logger.warn( |
| 63 | f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(=" |
| 64 | f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. " |
| 65 | "Please refer to the documentation for more information. " |
| 66 | "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)", |
| 67 | UserWarning, |
| 68 | ) |
| 69 | |
| 70 | if input_ids_seq_length >= generation_config.max_length: |
| 71 | input_ids_string = "input_ids" |
| 72 | logger.warning( |
| 73 | f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to" |
| 74 | f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" |
| 75 | " increasing `max_new_tokens`." |
| 76 | ) |
| 77 | |
| 78 | # 2. Set generation parameters if not already defined |
| 79 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() |
| 80 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() |
| 81 | |
| 82 | logits_processor = model._get_logits_processor( |
no test coverage detected