RequestHandler is the core component for handling existing requests and updating the current batch. During the generation process, we call the schedule function each iteration to update the current batch. Args: inference_config: Configuration for initializing and managing the KV cache. model_config: Configuration for the model.
| 138 | |
| 139 | |
| 140 | class RequestHandler(NaiveRequestHandler): |
| 141 | """ |
| 142 | RequestHandler is the core for handling existing requests and updating current batch. |
| 143 | During generation process, we call schedule function each iteration to update current batch. |
| 144 | |
| 145 | Args: |
| 146 | inference_config: Configuration for initialize and manage kv cache. |
| 147 | model_config: Configuration for model |
| 148 | dtype (torch.dtype): The data type for weights and activations. |
| 149 | """ |
| 150 | |
| 151 | def __init__(self, inference_config: InferenceConfig, model_config: PretrainedConfig) -> None: |
| 152 | self.inference_config = inference_config |
| 153 | self.running_list: RunningList = RunningList(inference_config.prefill_ratio) |
| 154 | self.waiting_list: List[List] = [[], [], []] |
| 155 | self.done_list: List[Sequence] = [] |
| 156 | self.dtype = inference_config.dtype |
| 157 | self.max_batch_size = inference_config.max_batch_size |
| 158 | |
| 159 | # initialize cache |
| 160 | self._init_cache(model_config) |
| 161 | |
| 162 | # initialize batch |
| 163 | device = torch.cuda.current_device() |
| 164 | kv_max_split_num = ( |
| 165 | inference_config.max_input_len + inference_config.max_output_len + inference_config.block_size - 1 |
| 166 | ) // inference_config.block_size |
| 167 | head_dim = model_config.hidden_size // model_config.num_attention_heads |
| 168 | |
| 169 | fd_inter_tensor = FDIntermTensors() |
| 170 | |
| 171 | if fd_inter_tensor._tensors_initialized: |
| 172 | fd_inter_tensor._reset() |
| 173 | |
| 174 | # For Spec-Dec, process the speculated tokens plus the token in the last step for each seq |
| 175 | max_n_tokens = self.max_batch_size |
| 176 | max_n_tokens *= self.inference_config.max_n_spec_tokens + 1 |
| 177 | |
| 178 | fd_inter_tensor.initialize( |
| 179 | max_batch_size=max_n_tokens, |
| 180 | num_attn_heads=model_config.num_attention_heads // inference_config.tp_size, |
| 181 | kv_max_split_num=kv_max_split_num, |
| 182 | head_dim=head_dim, |
| 183 | dtype=self.dtype, |
| 184 | device=device, |
| 185 | ) |
| 186 | |
| 187 | # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size, |
| 188 | # which may cause bugs and this issue should be fixed later. |
| 189 | self.running_bb = BatchBucket( |
| 190 | num_heads=model_config.num_attention_heads // inference_config.tp_size, |
| 191 | head_dim=head_dim, |
| 192 | max_batch_size=self.max_batch_size, |
| 193 | max_length=inference_config.max_input_len + inference_config.max_output_len, |
| 194 | block_size=inference_config.block_size, |
| 195 | kv_max_split_num=kv_max_split_num, |
| 196 | fd_interm_tensor=fd_inter_tensor, |
| 197 | dtype=self.dtype, |
no outgoing calls
searching dependent graphs…