hub / github.com/hpcaitech/ColossalAI / trim_kv_cache

Method trim_kv_cache

colossalai/inference/spec/drafter.py:40–62

Trim the last `invalid_token_num` kv caches.

past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
invalid_token_num (int): The number of invalid tokens to trim.

(past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int)

Source from the content-addressed store, hash-verified

@staticmethod
def trim_kv_cache(
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int
) -> Tuple[Tuple[torch.FloatTensor]]:
    """Trim the last `invalid_token_num` kv caches.

    past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape
        num_layers x 2 x (bsz x num_heads x seq_len x head_dim)
    invalid_token_num (int): The number of invalid tokens to trim.
    """
    if past_key_values is None or invalid_token_num < 1:
        return past_key_values

    trimmed_past_key_values = []
    for layer_idx in range(len(past_key_values)):
        past_key_value = past_key_values[layer_idx]
        trimmed_past_key_values.append(
            (
                past_key_value[0][:, :, :-invalid_token_num, :],
                past_key_value[1][:, :, :-invalid_token_num, :],
            )
        )
    past_key_values = tuple(trimmed_past_key_values)
    return past_key_values
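
To make the slicing concrete, here is a minimal sketch that applies the same logic to dummy tensors. It is a standalone function that mirrors the method rather than an import from ColossalAI. The `KVCache` alias and the sizes (`num_layers`, `bsz`, `num_heads`, `seq_len`, `head_dim`) are assumptions made up for this illustration.

```python
from typing import Optional, Tuple

import torch

# Hypothetical alias for this sketch: one (key, value) pair per layer.
KVCache = Tuple[Tuple[torch.Tensor, torch.Tensor], ...]


def trim_kv_cache(past_key_values: Optional[KVCache], invalid_token_num: int) -> Optional[KVCache]:
    """Standalone mirror of the method's logic, for illustration only."""
    if past_key_values is None or invalid_token_num < 1:
        # Nothing to trim: the input object is returned unchanged.
        return past_key_values
    # Drop the last `invalid_token_num` positions along the seq_len axis (dim 2).
    return tuple(
        (key[:, :, :-invalid_token_num, :], value[:, :, :-invalid_token_num, :])
        for key, value in past_key_values
    )


# Dummy cache with assumed sizes: 2 layers, batch 1, 4 heads, 10 tokens, head_dim 8.
num_layers, bsz, num_heads, seq_len, head_dim = 2, 1, 4, 10, 8
cache = tuple(
    (
        torch.randn(bsz, num_heads, seq_len, head_dim),
        torch.randn(bsz, num_heads, seq_len, head_dim),
    )
    for _ in range(num_layers)
)

trimmed = trim_kv_cache(cache, invalid_token_num=3)
trimmed_shape = tuple(trimmed[0][0].shape)  # seq_len shrinks from 10 to 7
unchanged = trim_kv_cache(cache, invalid_token_num=0)  # early-return path
```

Because basic slicing in PyTorch returns views, the trimmed tensors share storage with the originals rather than copying them. Note also that passing an `invalid_token_num` greater than or equal to `seq_len` leaves a cache with zero positions, since the method does not check that bound.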

Callers 2

test_drafter (Function) · 0.95
steps_spec_dec (Method) · 0.80

Calls 1

append (Method) · 0.45

Tested by 1

test_drafter (Function) · 0.76