MCPcopy
hub / github.com/lm-sys/FastChat / generate_stream

Function generate_stream

fastchat/serve/inference.py:62–316  ·  view source on GitHub ↗
(
    model,
    tokenizer,
    params: Dict,
    device: str,
    context_len: int,
    stream_interval: int = 2,
    judge_sent_end: bool = False,
)

Source from the content-addressed store, hash-verified

60
61@torch.inference_mode()
62def generate_stream(
63 model,
64 tokenizer,
65 params: Dict,
66 device: str,
67 context_len: int,
68 stream_interval: int = 2,
69 judge_sent_end: bool = False,
70):
71 if hasattr(model, "device"):
72 device = model.device
73
74 # Read parameters
75 prompt = params["prompt"]
76 len_prompt = len(prompt)
77 temperature = float(params.get("temperature", 1.0))
78 repetition_penalty = float(params.get("repetition_penalty", 1.0))
79 top_p = float(params.get("top_p", 1.0))
80 top_k = int(params.get("top_k", -1)) # -1 means disable
81 max_new_tokens = int(params.get("max_new_tokens", 256))
82 logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1.
83 echo = bool(params.get("echo", True))
84 stop_str = params.get("stop", None)
85 stop_token_ids = params.get("stop_token_ids", None) or []
86 if tokenizer.eos_token_id not in stop_token_ids:
87 stop_token_ids.append(tokenizer.eos_token_id)
88
89 logits_processor = prepare_logits_processor(
90 temperature, repetition_penalty, top_p, top_k
91 )
92 input_ids = tokenizer(prompt).input_ids
93
94 if model.config.is_encoder_decoder:
95 max_src_len = context_len
96 else: # truncate
97 max_src_len = context_len - max_new_tokens - 1
98
99 input_ids = input_ids[-max_src_len:]
100 output_ids = list(input_ids)
101 input_echo_len = len(input_ids)
102
103 if model.config.is_encoder_decoder:
104 if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models.
105 raise NotImplementedError
106 encoder_output = model.encoder(
107 input_ids=torch.as_tensor([input_ids], device=device)
108 )[0]
109 start_ids = torch.as_tensor(
110 [[model.generation_config.decoder_start_token_id]],
111 dtype=torch.int64,
112 device=device,
113 )
114 else:
115 start_ids = torch.as_tensor([input_ids], device=device)
116
117 past_key_values = out = None
118 token_logprobs = [None] # The first token has no logprobs.
119 sent_interrupt = False

Callers (1)

generate (method) · 0.90

Calls (4)

is_sentence_complete (function) · 0.90
is_partial_stop (function) · 0.90
prepare_logits_processor (function) · 0.85
to (method) · 0.80

Tested by

no test coverage detected

Used in the wild: real call sites across dependent graphs

searching dependent graphs…