MCP · copy
hub / github.com/InternLM/Tutorial / generate_interactive

Function generate_interactive

tools/streamlit_demo.py:46–175  ·  view source on GitHub ↗
(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
                                                List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
)

Source from the content-addressed store, hash-verified

44
45@torch.inference_mode()
46def generate_interactive(
47 model,
48 tokenizer,
49 prompt,
50 generation_config: Optional[GenerationConfig] = None,
51 logits_processor: Optional[LogitsProcessorList] = None,
52 stopping_criteria: Optional[StoppingCriteriaList] = None,
53 prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
54 List[int]]] = None,
55 additional_eos_token_id: Optional[int] = None,
56 **kwargs,
57):
58 inputs = tokenizer([prompt], padding=True, return_tensors='pt')
59 input_length = len(inputs['input_ids'][0])
60 for k, v in inputs.items():
61 inputs[k] = v.cuda()
62 input_ids = inputs['input_ids']
63 _, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
64 if generation_config is None:
65 generation_config = model.generation_config
66 generation_config = copy.deepcopy(generation_config)
67 model_kwargs = generation_config.update(**kwargs)
68 bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
69 generation_config.bos_token_id,
70 generation_config.eos_token_id,
71 )
72 if isinstance(eos_token_id, int):
73 eos_token_id = [eos_token_id]
74 if additional_eos_token_id is not None:
75 eos_token_id.append(additional_eos_token_id)
76 has_default_max_length = kwargs.get(
77 'max_length') is None and generation_config.max_length is not None
78 if has_default_max_length and generation_config.max_new_tokens is None:
79 warnings.warn(
80 f"Using 'max_length''s default \
81 ({repr(generation_config.max_length)}) \
82 to control the generation length. "
83 'This behaviour is deprecated and will be removed from the \
84 config in v5 of Transformers -- we'
85 ' recommend using `max_new_tokens` to control the maximum \
86 length of the generation.',
87 UserWarning,
88 )
89 elif generation_config.max_new_tokens is not None:
90 generation_config.max_length = generation_config.max_new_tokens + \
91 input_ids_seq_length
92 if not has_default_max_length:
93 logger.warn( # pylint: disable=W4902
94 f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
95 f"and 'max_length'(={generation_config.max_length}) seem to "
96 "have been set. 'max_new_tokens' will take precedence. "
97 'Please refer to the documentation for more information. '
98 '(https://huggingface.co/docs/transformers/main/'
99 'en/main_classes/text_generation)',
100 UserWarning,
101 )
102
103 if input_ids_seq_length >= generation_config.max_length:

Callers 1

main · Function · 0.70

Calls

no outgoing calls

Tested by

no test coverage detected