(
model,
tokenizer,
params: Dict,
device: str,
context_len: int,
stream_interval: int = 2,
judge_sent_end: bool = False,
)
| 60 | |
| 61 | @torch.inference_mode() |
| 62 | def generate_stream( |
| 63 | model, |
| 64 | tokenizer, |
| 65 | params: Dict, |
| 66 | device: str, |
| 67 | context_len: int, |
| 68 | stream_interval: int = 2, |
| 69 | judge_sent_end: bool = False, |
| 70 | ): |
| 71 | if hasattr(model, "device"): |
| 72 | device = model.device |
| 73 | |
| 74 | # Read parameters |
| 75 | prompt = params["prompt"] |
| 76 | len_prompt = len(prompt) |
| 77 | temperature = float(params.get("temperature", 1.0)) |
| 78 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) |
| 79 | top_p = float(params.get("top_p", 1.0)) |
| 80 | top_k = int(params.get("top_k", -1)) # -1 means disable |
| 81 | max_new_tokens = int(params.get("max_new_tokens", 256)) |
| 82 | logprobs = params.get("logprobs", None) # FIXME: Support logprobs>1. |
| 83 | echo = bool(params.get("echo", True)) |
| 84 | stop_str = params.get("stop", None) |
| 85 | stop_token_ids = params.get("stop_token_ids", None) or [] |
| 86 | if tokenizer.eos_token_id not in stop_token_ids: |
| 87 | stop_token_ids.append(tokenizer.eos_token_id) |
| 88 | |
| 89 | logits_processor = prepare_logits_processor( |
| 90 | temperature, repetition_penalty, top_p, top_k |
| 91 | ) |
| 92 | input_ids = tokenizer(prompt).input_ids |
| 93 | |
| 94 | if model.config.is_encoder_decoder: |
| 95 | max_src_len = context_len |
| 96 | else: # truncate |
| 97 | max_src_len = context_len - max_new_tokens - 1 |
| 98 | |
| 99 | input_ids = input_ids[-max_src_len:] |
| 100 | output_ids = list(input_ids) |
| 101 | input_echo_len = len(input_ids) |
| 102 | |
| 103 | if model.config.is_encoder_decoder: |
| 104 | if logprobs is not None: # FIXME: Support logprobs for encoder-decoder models. |
| 105 | raise NotImplementedError |
| 106 | encoder_output = model.encoder( |
| 107 | input_ids=torch.as_tensor([input_ids], device=device) |
| 108 | )[0] |
| 109 | start_ids = torch.as_tensor( |
| 110 | [[model.generation_config.decoder_start_token_id]], |
| 111 | dtype=torch.int64, |
| 112 | device=device, |
| 113 | ) |
| 114 | else: |
| 115 | start_ids = torch.as_tensor([input_ids], device=device) |
| 116 | |
| 117 | past_key_values = out = None |
| 118 | token_logprobs = [None] # The first token has no logprobs. |
| 119 | sent_interrupt = False |
no test coverage detected
searching dependent graphs…