hub / github.com/hpcaitech/ColossalAI / steps_spec_dec

Method steps_spec_dec

colossalai/inference/core/llm_engine.py:388–494

Runs the speculative decoding steps: retrieves a single batch, then launches inference in which a drafter model speculates several tokens per step and a main model verifies them. Returns: List[Sequence]: the finished sequences generated by one step.
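The draft-then-verify idea described above can be illustrated with a minimal, self-contained sketch. It is not the ColossalAI implementation: it ignores KV caches, CUDA graphs and glide, assumes greedy decoding, and calls the verifier once per position where a real engine would score all drafted positions in one forward pass. The names `speculative_decode`, `drafter_model` and `verifier_model` are hypothetical and exist only for this example.

```python
from typing import Callable, List

# Hypothetical stand-in for a language model: maps a token prefix to the
# next token under greedy decoding.
NextTokenFn = Callable[[List[int]], int]


def speculative_decode(
    prompt: List[int],
    drafter: NextTokenFn,
    verifier: NextTokenFn,
    n_spec_tokens: int,
    max_new_tokens: int,
) -> List[int]:
    """Greedy speculative decoding for a single sequence (batch size 1)."""
    tokens = list(prompt)
    target_len = len(prompt) + max_new_tokens
    while len(tokens) < target_len:
        # Drafter proposes n_spec_tokens tokens autoregressively.
        draft: List[int] = []
        for _ in range(n_spec_tokens):
            draft.append(drafter(tokens + draft))

        # Verifier checks each drafted position in order.
        accepted: List[int] = []
        for i, drafted in enumerate(draft):
            expected = verifier(tokens + draft[:i])
            accepted.append(expected)  # the verifier's token is always kept
            if expected != drafted:
                break  # first mismatch: drop the remaining drafted tokens
        else:
            # Every drafted token matched, so the verifier adds one bonus token.
            accepted.append(verifier(tokens + draft))

        tokens.extend(accepted)
    return tokens[:target_len]


# Toy models (assumptions for illustration): the verifier counts modulo 10,
# and the drafter agrees except right after token 4, where it guesses 0.
def verifier_model(prefix: List[int]) -> int:
    return (prefix[-1] + 1) % 10


def drafter_model(prefix: List[int]) -> int:
    return 0 if prefix[-1] == 4 else (prefix[-1] + 1) % 10


result = speculative_decode([3], drafter_model, verifier_model, n_spec_tokens=3, max_new_tokens=8)
print(result)
```

Because a mismatching draft token is replaced by the verifier's own token, the output matches what the verifier would produce by decoding greedily on its own; the drafter only changes how many verifier steps are needed.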

(self)

Source from the content-addressed store, hash-verified

        self.use_spec_dec = False

    def steps_spec_dec(self) -> List[Sequence]:
        """
        Run Speculative Decoding steps. This is like retrieving a single batch and launch inference
        with many steps of speculating by a drafter model as well as verifying by a main model.

        Returns:
            List[Sequence]: finished sequences generated by one step.
        """
        batch = self.request_handler.schedule()  # prefill batch
        assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

        input_token_ids, output_tensor, input_meta_data = self.prepare_input(batch)

        if input_meta_data.use_cuda_graph:
            model_executable = self.graph_runners[input_meta_data.batch_size]
        else:
            model_executable = self.model

        # 1. Prefill small model (Drafter) - fill past kv cache for drafter model
        # NOTE For glide drafter models, we won't actually apply glide during prefill stage
        drafter_out = self.drafter.speculate(input_token_ids, 1, None)
        next_token_ids_spec = drafter_out.next_tokens
        drafter_past_key_values = drafter_out.past_key_values

        # 2. Prefill main model (Verifier) - fill past kv cache for main model
        logits = model_executable(input_token_ids, output_tensor, input_meta_data, self.k_cache, self.v_cache)
        next_tokens = search_tokens(self.generation_config, logits, batch_token_ids=batch.batch_token_ids)
        # append new inputs to the batch, temporarily
        batch.append_batch_tokens(next_tokens)
        self.request_handler.allocate_batch_spec_dec(batch, 1)
        already_allocated_kv_len = batch.seq_lengths[0].item()
        input_token_ids = batch.get_1D_inputs_spec_dec(1)

        finished_sequences = self.request_handler.update()

        while True:
            # HACK Retrieve the running batch
            #      Using RequestHandler.schedule here will re-allocate same kv cache for the batch
            batch = self.request_handler.running_bb  # running batch
            assert batch.current_batch_size == 1, "Only support bsz 1 for speculative decoding for now."

            # 3. Decoding - Drafter model speculates `n` tokens
            glide_input = None
            if self.use_glide:
                glide_input = GlideInput(
                    batch.get_block_table_tensor(),
                    self.k_cache[-1],  # use kv caches of the last layer
                    self.v_cache[-1],
                    batch.get_sequence_lengths(),
                    n_spec_tokens=self.n_spec_tokens,
                )

            drafter_out = self.drafter.speculate(
                input_token_ids,
                self.n_spec_tokens,
                drafter_past_key_values,
                glide_input=glide_input,
            )

Callers 1

generate · Method · 0.95

Calls 15

prepare_input · Method · 0.95
search_tokens · Function · 0.90
GlideInput · Class · 0.90
speculate · Method · 0.80
append_batch_tokens · Method · 0.80
get_sequence_lengths · Method · 0.80
append_next_tokens · Method · 0.80
set_use_spec_dec · Method · 0.80
revoke_batch_tokens · Method · 0.80

Tested by

no test coverage detected