github.com/hpcaitech/ColossalAI

Class Drafter

colossalai/inference/spec/drafter.py:13–123

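Container for the drafter model (assistant model) used in speculative decoding.

Args:
model (nn.Module): The drafter model.
tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
device (torch.device): The device for the drafter model; defaults to the current device when omitted.
dtype (torch.dtype): The dtype the drafter model is cast to; defaults to torch.float16.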
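In speculative decoding, the drafter cheaply proposes `n_spec_tokens` tokens and the main model checks them in a single forward pass. Drafted tokens after the first rejection are discarded, and their cached keys and values become invalid. The following is a minimal sketch of greedy (exact-match) verification. It assumes the main model's greedy predictions are already available as a list of token IDs. `count_accepted` and `verify_step` are hypothetical helpers, not part of ColossalAI, and ColossalAI's own acceptance logic may differ. The count of rejected draft tokens is the kind of quantity that `trim_kv_cache`'s `invalid_token_num` parameter describes.

```python
from typing import List, Sequence, Tuple


def count_accepted(draft_tokens: Sequence[int], target_tokens: Sequence[int]) -> int:
    """Length of the longest prefix of draft_tokens that the main model agrees with."""
    accepted = 0
    for drafted, verified in zip(draft_tokens, target_tokens):
        if drafted != verified:
            break
        accepted += 1
    return accepted


def verify_step(draft_tokens: Sequence[int], target_tokens: Sequence[int]) -> Tuple[List[int], int]:
    """Hypothetical greedy verification step.

    target_tokens holds the main model's greedy token at each drafted position plus one
    extra position, so len(target_tokens) == len(draft_tokens) + 1.
    Returns (tokens to append, number of drafted tokens to discard).
    """
    n_spec_tokens = len(draft_tokens)
    accepted = count_accepted(draft_tokens, target_tokens)
    # Keep the accepted drafts, then take the main model's own token at the first
    # mismatch (or its extra "bonus" token when every draft was accepted).
    new_tokens = list(draft_tokens[:accepted]) + [target_tokens[accepted]]
    invalid_token_num = n_spec_tokens - accepted
    return new_tokens, invalid_token_num


if __name__ == "__main__":
    new_tokens, invalid = verify_step([5, 7, 9, 11], [5, 7, 2, 4, 8])
    print(new_tokens, invalid)  # [5, 7, 2] 2
```

Even when the first draft is rejected, each step still yields one token from the main model. This is why speculative decoding never produces fewer tokens per main-model pass than ordinary decoding.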


```python
class Drafter:
    """Container for the Drafter Model (Assistant Model) used in Speculative Decoding.

    Args:
        model (nn.Module): The drafter model.
        tokenizer (transformers.PreTrainedTokenizer): The tokenizer for the drafter model.
        device (torch.device): The device for the drafter model. Defaults to the current device.
        dtype (torch.dtype): The dtype the drafter model is cast to. Defaults to torch.float16.
    """

    def __init__(
        self,
        model: nn.Module,
        tokenizer: PreTrainedTokenizer,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float16,
    ):
        self._tokenizer = tokenizer
        # Fall back to the current device when none is given.
        self._device = device or get_current_device()
        self._dtype = dtype
        # Move and cast in one call; nn.Module.to() modifies the module in place and returns it.
        self._drafter_model = model.to(device=self._device, dtype=self._dtype)
        # The drafter is only used for inference.
        self._drafter_model.eval()

    def get_model(self) -> nn.Module:
        return self._drafter_model

    @staticmethod
    def trim_kv_cache(
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int
    ) -> Optional[Tuple[Tuple[torch.FloatTensor]]]:
        """Trim the KV cache entries of the last `invalid_token_num` tokens.

        Args:
            past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape
                num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
            invalid_token_num (int): The number of invalid tokens to trim.
        """
        # Nothing to trim. The guard also matters because a [:-0] slice is empty.
        if past_key_values is None or invalid_token_num < 1:
            return past_key_values

        trimmed_past_key_values = []
        for layer_idx in range(len(past_key_values)):
            past_key_value = past_key_values[layer_idx]
            # Drop the last `invalid_token_num` positions along the seq_len axis of K and V.
            trimmed_past_key_values.append(
                (
                    past_key_value[0][:, :, :-invalid_token_num, :],
                    past_key_value[1][:, :, :-invalid_token_num, :],
                )
            )
        past_key_values = tuple(trimmed_past_key_values)
        return past_key_values

    @torch.inference_mode()
    def speculate(
        self,
        input_ids: torch.Tensor,
        n_spec_tokens: int,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        glide_input: Optional[GlideInput] = None,
        # (excerpt ends here; the rest of the signature and the body are not shown)
```
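`trim_kv_cache` removes the last `invalid_token_num` positions along the `seq_len` axis of every layer's key and value tensors. The sketch below reproduces that slicing in plain Python so it runs without PyTorch. It uses nested lists as stand-ins for tensors of shape bsz x num_heads x seq_len x head_dim. `make_cache`, `trim_seq_axis`, and `shape` are hypothetical helpers for illustration only.

The sketch also shows why the `invalid_token_num < 1` guard matters. A `[:-0]` slice is the same as `[:0]`, which is empty for both Python lists and PyTorch tensors. Without the guard, trimming zero tokens would wipe the whole cache.

```python
from typing import List, Optional, Tuple

# Hypothetical stand-in for a tensor of shape bsz x num_heads x seq_len x head_dim.
Tensor4D = List[List[List[List[float]]]]
LayerKV = Tuple[Tensor4D, Tensor4D]


def trim_seq_axis(t: Tensor4D, n: int) -> Tensor4D:
    # Nested-list equivalent of t[:, :, :-n, :]: slice the third (seq_len) axis.
    return [[head[:-n] for head in batch] for batch in t]


def trim_kv_cache(
    past_key_values: Optional[Tuple[LayerKV, ...]], invalid_token_num: int
) -> Optional[Tuple[LayerKV, ...]]:
    # Same guard as Drafter.trim_kv_cache: with n == 0, head[:-0] would be empty.
    if past_key_values is None or invalid_token_num < 1:
        return past_key_values
    return tuple(
        (trim_seq_axis(k, invalid_token_num), trim_seq_axis(v, invalid_token_num))
        for k, v in past_key_values
    )


def make_cache(num_layers: int, bsz: int, num_heads: int, seq_len: int, head_dim: int):
    def tensor() -> Tensor4D:
        # Fill each position with its sequence index so trimming is easy to see.
        return [
            [[[float(s)] * head_dim for s in range(seq_len)] for _ in range(num_heads)]
            for _ in range(bsz)
        ]

    return tuple((tensor(), tensor()) for _ in range(num_layers))


def shape(t: Tensor4D) -> Tuple[int, int, int, int]:
    return (len(t), len(t[0]), len(t[0][0]), len(t[0][0][0]))


if __name__ == "__main__":
    cache = make_cache(num_layers=2, bsz=1, num_heads=2, seq_len=5, head_dim=3)
    trimmed = trim_kv_cache(cache, 2)
    print(shape(cache[0][0]), "->", shape(trimmed[0][0]))  # (1, 2, 5, 3) -> (1, 2, 3, 3)
```

The static method returns a new tuple rather than mutating its input. The caller is expected to replace its reference to the cache with the returned value.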

Callers (3)

enable_spec_dec · method · 0.90
test_drafter · function · 0.90
test_spec_dec · function · 0.90

Calls

no outgoing calls

Tested by (2)

test_drafter · function · 0.72
test_spec_dec · function · 0.72
