Run speculative decoding steps. This is like retrieving a single batch and launching inference with many steps of speculation by a drafter model and verification by a main model. Returns: List[Sequence]: finished sequences generated by one step.
(self)
| 386 | self.use_spec_dec = False |
| 387 | |
| 388 | def steps_spec_dec(self) -> List[Sequence]: |
| 389 | """ |
| 390 | Run Speculative Decoding steps. This is like retrieving a single batch and launching inference |
| 391 | with many steps of speculation by a drafter model as well as verification by a main model. |
| 392 | |
| 393 | Returns: |
| 394 | List[Sequence]: finished sequences generated by one step. |
| 395 | """ |
| 396 | batch = self.request_handler.schedule() # prefill batch |
| 397 | assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." |
| 398 | |
| 399 | input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch) |
| 400 | |
| 401 | if input_meta_data.use_cuda_graph: |
| 402 | model_executable = self.graph_runners[input_meta_data.batch_size] |
| 403 | else: |
| 404 | model_executable = self.model |
| 405 | |
| 406 | # 1. Prefill small model (Drafter) - fill past kv cache for drafter model |
| 407 | # NOTE For glide drafter models, we won't actually apply glide during prefill stage |
| 408 | drafter_out = self.drafter.speculate(input_token_ids, 1, None) |
| 409 | next_token_ids_spec = drafter_out.next_tokens |
| 410 | drafter_past_key_values = drafter_out.past_key_values |
| 411 | |
| 412 | # 2. Prefill main model (Verifier) - fill past kv cache for main model |
| 413 | logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache) |
| 414 | next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids) |
| 415 | # append new inputs to the batch, temporarily |
| 416 | batch.append_batch_tokens(next_tokens) |
| 417 | self.request_handler.allocate_batch_spec_dec(batch, 1) |
| 418 | already_allocated_kv_len = batch.seq_lengths[0].item() |
| 419 | input_token_ids = batch.get_1D_inputs_spec_dec(1) |
| 420 | |
| 421 | finished_sequences = self.request_handler.update() |
| 422 | |
| 423 | while True: |
| 424 | # HACK Retrieve the running batch |
| 425 | # Using RequestHandler.schedule here will re-allocate same kv cache for the batch |
| 426 | batch = self.request_handler.running_bb # running batch |
| 427 | assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now." |
| 428 | |
| 429 | # 3. Decoding - Drafter model speculates `n` tokens |
| 430 | glide_input = None |
| 431 | if self.use_glide: |
| 432 | glide_input = GlideInput( |
| 433 | batch.get_block_table_tensor(), |
| 434 | self.k_cache[-1], # use kv caches of the last layer |
| 435 | self.v_cache[-1], |
| 436 | batch.get_sequence_lengths(), |
| 437 | n_spec_tokens=self.n_spec_tokens, |
| 438 | ) |
| 439 | |
| 440 | drafter_out = self.drafter.speculate( |
| 441 | input_token_ids, |
| 442 | self.n_spec_tokens, |
| 443 | drafter_past_key_values, |
| 444 | glide_input=glide_input, |
| 445 | ) |
no test coverage detected