| 44 | |
| 45 | @torch.inference_mode() |
| 46 | def generate_interactive( |
| 47 | model, |
| 48 | tokenizer, |
| 49 | prompt, |
| 50 | generation_config: Optional[GenerationConfig] = None, |
| 51 | logits_processor: Optional[LogitsProcessorList] = None, |
| 52 | stopping_criteria: Optional[StoppingCriteriaList] = None, |
| 53 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], |
| 54 | List[int]]] = None, |
| 55 | additional_eos_token_id: Optional[int] = None, |
| 56 | **kwargs, |
| 57 | ): |
| 58 | inputs = tokenizer([prompt], padding=True, return_tensors='pt') |
| 59 | input_length = len(inputs['input_ids'][0]) |
| 60 | for k, v in inputs.items(): |
| 61 | inputs[k] = v.cuda() |
| 62 | input_ids = inputs['input_ids'] |
| 63 | _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1] |
| 64 | if generation_config is None: |
| 65 | generation_config = model.generation_config |
| 66 | generation_config = copy.deepcopy(generation_config) |
| 67 | model_kwargs = generation_config.update(**kwargs) |
| 68 | bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612 |
| 69 | generation_config.bos_token_id, |
| 70 | generation_config.eos_token_id, |
| 71 | ) |
| 72 | if isinstance(eos_token_id, int): |
| 73 | eos_token_id = [eos_token_id] |
| 74 | if additional_eos_token_id is not None: |
| 75 | eos_token_id.append(additional_eos_token_id) |
| 76 | has_default_max_length = kwargs.get( |
| 77 | 'max_length') is None and generation_config.max_length is not None |
| 78 | if has_default_max_length and generation_config.max_new_tokens is None: |
| 79 | warnings.warn( |
| 80 | f"Using 'max_length''s default \ |
| 81 | ({repr(generation_config.max_length)}) \ |
| 82 | to control the generation length. " |
| 83 | 'This behaviour is deprecated and will be removed from the \ |
| 84 | config in v5 of Transformers -- we' |
| 85 | ' recommend using `max_new_tokens` to control the maximum \ |
| 86 | length of the generation.', |
| 87 | UserWarning, |
| 88 | ) |
| 89 | elif generation_config.max_new_tokens is not None: |
| 90 | generation_config.max_length = generation_config.max_new_tokens + \ |
| 91 | input_ids_seq_length |
| 92 | if not has_default_max_length: |
| 93 | logger.warn( # pylint: disable=W4902 |
| 94 | f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) " |
| 95 | f"and 'max_length'(={generation_config.max_length}) seem to " |
| 96 | "have been set. 'max_new_tokens' will take precedence. " |
| 97 | 'Please refer to the documentation for more information. ' |
| 98 | '(https://huggingface.co/docs/transformers/main/' |
| 99 | 'en/main_classes/text_generation)', |
| 100 | UserWarning, |
| 101 | ) |
| 102 | |
| 103 | if input_ids_seq_length >= generation_config.max_length: |