MCPcopy
hub / github.com/InternLM/InternLM / generate_interactive

Function generate_interactive

tools/transformers/interface.py:25–137  ·  view source on GitHub ↗
(
    model, 
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    logits_processor: Optional[LogitsProcessorList] = None,
    stopping_criteria: Optional[StoppingCriteriaList] = None,
    prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
    additional_eos_token_id: Optional[int] = None,
    **kwargs,
)

Source from the content-addressed store, hash-verified

23
24@torch.inference_mode()
25def generate_interactive(
26 model,
27 tokenizer,
28 prompt,
29 generation_config: Optional[GenerationConfig] = None,
30 logits_processor: Optional[LogitsProcessorList] = None,
31 stopping_criteria: Optional[StoppingCriteriaList] = None,
32 prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
33 additional_eos_token_id: Optional[int] = None,
34 **kwargs,
35):
36 inputs = tokenizer([prompt], padding=True, return_tensors="pt")
37 input_length = len(inputs["input_ids"][0])
38 for k, v in inputs.items():
39 inputs[k] = v.cuda()
40 input_ids = inputs["input_ids"]
41 batch_size, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
42 if generation_config is None:
43 generation_config = model.generation_config
44 generation_config = copy.deepcopy(generation_config)
45 model_kwargs = generation_config.update(**kwargs)
46 bos_token_id, eos_token_id = generation_config.bos_token_id, generation_config.eos_token_id
47 if isinstance(eos_token_id, int):
48 eos_token_id = [eos_token_id]
49 if additional_eos_token_id is not None:
50 eos_token_id.append(additional_eos_token_id)
51 has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
52 if has_default_max_length and generation_config.max_new_tokens is None:
53 warnings.warn(
54 f"Using `max_length`'s default ({generation_config.max_length}) to control the generation length. "
55 "This behaviour is deprecated and will be removed from the config in v5 of Transformers -- we"
56 " recommend using `max_new_tokens` to control the maximum length of the generation.",
57 UserWarning,
58 )
59 elif generation_config.max_new_tokens is not None:
60 generation_config.max_length = generation_config.max_new_tokens + input_ids_seq_length
61 if not has_default_max_length:
62 logger.warn(
63 f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
64 f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
65 "Please refer to the documentation for more information. "
66 "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)",
67 UserWarning,
68 )
69
70 if input_ids_seq_length >= generation_config.max_length:
71 input_ids_string = "input_ids"
72 logger.warning(
73 f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
74 f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
75 " increasing `max_new_tokens`."
76 )
77
78 # 2. Set generation parameters if not already defined
79 logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
80 stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
81
82 logits_processor = model._get_logits_processor(

Callers 2

mainFunction · 0.90
generateMethod · 0.90

Calls 2

updateMethod · 0.45

Tested by

no test coverage detected