Trim the last `invalid_token_num` positions from the kv caches. Parameters: `past_key_values` (Tuple[Tuple[torch.FloatTensor]]) is the past key values, with shape num_layers x 2 x (bsz x num_heads x seq_len x head_dim); `invalid_token_num` (int) is the number of invalid tokens to trim.
(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int
)
| 38 | |
@staticmethod
def trim_kv_cache(
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]], invalid_token_num: int
) -> Optional[Tuple[Tuple[torch.FloatTensor]]]:
    """Trim the last `invalid_token_num` kv caches.

    Args:
        past_key_values (Tuple[Tuple[torch.FloatTensor]]): The past key values with shape
            num_layers x 2 x (bsz x num_heads x seq_len x head_dim). May be None.
        invalid_token_num (int): The number of invalid tokens to trim.

    Returns:
        `past_key_values` itself (possibly None) when it is None or when
        `invalid_token_num` is less than 1; otherwise a new tuple with one
        (key, value) pair per layer, each sliced to drop the last
        `invalid_token_num` entries along dim 2.
    """
    # Nothing to trim: hand the input back untouched (this is why the return
    # type is Optional -- a None cache stays None).
    if past_key_values is None or invalid_token_num < 1:
        return past_key_values

    # Index 0 is the key tensor and index 1 the value tensor of each layer;
    # both are cut on the third axis (seq_len per the shape documented above).
    return tuple(
        (
            layer_cache[0][:, :, :-invalid_token_num, :],
            layer_cache[1][:, :, :-invalid_token_num, :],
        )
        for layer_cache in past_key_values
    )
| 63 | |
| 64 | @torch.inference_mode() |
| 65 | def speculate( |