| 241 | |
| 242 | |
def _trim_recent_cache(cache: List[Any], num_tokens: int) -> None:
    """Remove up to ``num_tokens`` of the newest positions from every entry.

    Each element of ``cache`` is modified in place; nothing is returned.
    The amount removed from an entry is capped at its ``offset`` attribute
    when it has one, and entries for which that amount is not positive are
    left untouched.

    * ``RotatingKVCache`` entries that hold keys are first put through
      their own ``_temporal_order`` helper, then sliced by hand, after
      which ``offset`` and ``_idx`` are updated to match.
    * Any other entry (including a ``RotatingKVCache`` whose ``keys`` is
      ``None``) is trimmed through its ``trim`` method when it has one,
      and silently skipped otherwise.

    Args:
        cache: Cache objects to shorten.
        num_tokens: Maximum number of trailing positions to drop per
            entry. Values ``<= 0`` make the call a no-op.
    """
    if num_tokens <= 0:
        return

    for entry in cache:
        # Entries without an ``offset`` are asked to drop the full amount.
        drop = min(getattr(entry, "offset", num_tokens), num_tokens)
        if drop <= 0:
            continue

        is_filled_rotating = (
            isinstance(entry, RotatingKVCache) and entry.keys is not None
        )
        if not is_filled_rotating:
            # Generic path: delegate to the entry's own trim() if it exists.
            if hasattr(entry, "trim"):
                entry.trim(drop)
            continue

        # Reorder both buffers before touching ``_idx``/``offset``, so the
        # helper sees the same entry state for keys and for values.
        entry.keys = entry._temporal_order(entry.keys)
        entry.values = entry._temporal_order(entry.values)

        # Same index as ``[..., :-drop, :]``: cut the tail of the
        # second-to-last axis.
        # NOTE(review): presumably the sequence axis -- confirm.
        head = (Ellipsis, slice(None, -drop), slice(None))
        entry.keys = entry.keys[head]
        entry.values = entry.values[head]

        entry.offset -= drop
        # NOTE(review): assumes axis 2 is the axis that was just sliced
        # (i.e. the buffers are 4-D) -- confirm.
        entry._idx = entry.keys.shape[2]
| 259 | |
| 260 | |
| 261 | class _LayerHook: |