(
model,
tokenizer,
params,
device,
context_len=2048,
stream_interval=2,
judge_sent_end=False,
)
| 12 | |
| 13 | @torch.inference_mode() |
| 14 | def generate_stream_codet5p( |
| 15 | model, |
| 16 | tokenizer, |
| 17 | params, |
| 18 | device, |
| 19 | context_len=2048, |
| 20 | stream_interval=2, |
| 21 | judge_sent_end=False, |
| 22 | ): |
| 23 | prompt = params["prompt"] |
| 24 | temperature = float(params.get("temperature", 1.0)) |
| 25 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) |
| 26 | top_p = float(params.get("top_p", 1.0)) |
| 27 | top_k = int(params.get("top_k", 50)) # -1 means disable |
| 28 | max_new_tokens = int(params.get("max_new_tokens", 1024)) |
| 29 | stop_token_ids = params.get("stop_token_ids", None) or [] |
| 30 | stop_token_ids.append(tokenizer.eos_token_id) |
| 31 | |
| 32 | decode_config = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) |
| 33 | streamer = TextIteratorStreamer(tokenizer, **decode_config) |
| 34 | encoding = tokenizer(prompt, return_tensors="pt").to(device) |
| 35 | input_ids = encoding.input_ids |
| 36 | encoding["decoder_input_ids"] = encoding["input_ids"].clone() |
| 37 | input_echo_len = len(input_ids) |
| 38 | |
| 39 | generation_config = GenerationConfig( |
| 40 | max_new_tokens=max_new_tokens, |
| 41 | do_sample=temperature >= 1e-5, |
| 42 | temperature=temperature, |
| 43 | repetition_penalty=repetition_penalty, |
| 44 | no_repeat_ngram_size=10, |
| 45 | top_p=top_p, |
| 46 | top_k=top_k, |
| 47 | eos_token_id=stop_token_ids, |
| 48 | ) |
| 49 | |
class CodeBlockStopper(StoppingCriteria):
    """Stopping criterion that fires when the sequence ends with ids 628, 198.

    Only the first sequence in the batch (``input_ids[0]``) is inspected.
    """

    def __call__(
        self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs
    ) -> bool:
        # Code completion is open-ended generation, so there is no natural
        # end-of-answer marker; instead we stop at the end of a code block,
        # detected as a trailing "\n\n".
        # NOTE(review): 628 and 198 are tokenizer-specific ids assumed to
        # encode that blank-line sequence -- confirm against the tokenizer.
        trailing_ids = list(input_ids[0][-2:])
        return trailing_ids == [628, 198]
| 59 | |
| 60 | gen_kwargs = dict( |
| 61 | **encoding, |
| 62 | streamer=streamer, |
| 63 | generation_config=generation_config, |
| 64 | stopping_criteria=StoppingCriteriaList([CodeBlockStopper()]), |
| 65 | ) |
| 66 | thread = Thread(target=model.generate, kwargs=gen_kwargs) |
| 67 | thread.start() |
| 68 | i = 0 |
| 69 | output = "" |
| 70 | for new_text in streamer: |
| 71 | i += 1 |
nothing calls this directly
no test coverage detected
searching dependent graphs…