MCPcopy
hub / github.com/lm-sys/FastChat / generate_stream_codet5p

Function generate_stream_codet5p

fastchat/model/model_codet5p.py:14–108  ·  view source on GitHub ↗
(
    model,
    tokenizer,
    params,
    device,
    context_len=2048,
    stream_interval=2,
    judge_sent_end=False,
)

Source from the content-addressed store, hash-verified

12
13@torch.inference_mode()
14def generate_stream_codet5p(
15 model,
16 tokenizer,
17 params,
18 device,
19 context_len=2048,
20 stream_interval=2,
21 judge_sent_end=False,
22):
23 prompt = params["prompt"]
24 temperature = float(params.get("temperature", 1.0))
25 repetition_penalty = float(params.get("repetition_penalty", 1.0))
26 top_p = float(params.get("top_p", 1.0))
27 top_k = int(params.get("top_k", 50)) # -1 means disable
28 max_new_tokens = int(params.get("max_new_tokens", 1024))
29 stop_token_ids = params.get("stop_token_ids", None) or []
30 stop_token_ids.append(tokenizer.eos_token_id)
31
32 decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)
33 streamer = TextIteratorStreamer(tokenizer, **decode_config)
34 encoding = tokenizer(prompt, return_tensors="pt").to(device)
35 input_ids = encoding.input_ids
36 encoding["decoder_input_ids"] = encoding["input_ids"].clone()
37 input_echo_len = len(input_ids)
38
39 generation_config = GenerationConfig(
40 max_new_tokens=max_new_tokens,
41 do_sample=temperature >= 1e-5,
42 temperature=temperature,
43 repetition_penalty=repetition_penalty,
44 no_repeat_ngram_size=10,
45 top_p=top_p,
46 top_k=top_k,
47 eos_token_id=stop_token_ids,
48 )
49
50 class CodeBlockStopper(StoppingCriteria):
51 def __call__(
52 self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
53 ) -> bool:
54 # Code-completion is open-end generation.
55 # We check \n\n to stop at end of a code block.
56 if list(input_ids[0][-2:]) == [628, 198]:
57 return True
58 return False
59
60 gen_kwargs = dict(
61 **encoding,
62 streamer=streamer,
63 generation_config=generation_config,
64 stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]),
65 )
66 thread = Thread(target=model.generate, kwargs=gen_kwargs)
67 thread.start()
68 i = 0
69 output = ""
70 for new_text in streamer:
71 i += 1

Callers

nothing calls this directly

Calls 2

CodeBlockStopper (class) · 0.85
to (method) · 0.80

Tested by

no test coverage detected

Used in the wild — real call sites across dependent graphs

searching dependent graphs…