(
self,
max_batch_size: int,
max_seq_len: int,
dtype: torch.dtype,
device: torch.device,
)
| 101 | self.args = args |
| 102 | |
def init_kv_cache(
    self,
    max_batch_size: int,
    max_seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
):
    """Allocate zero-filled key/value caches and attach them as buffers.

    Two separate tensors of shape
    ``(max_batch_size, max_seq_len, self.n_kv_heads, self.head_dim)`` are
    created and registered under the names ``cache_k`` and ``cache_v``.

    Args:
        max_batch_size: Size of the first cache dimension.
        max_seq_len: Size of the second cache dimension.
        dtype: Element type of both cache tensors.
        device: Device on which both cache tensors are allocated.

    NOTE(review): relies on ``self.register_buffer`` — presumably ``self``
    is a ``torch.nn.Module``; confirm against the enclosing class.
    """
    shape = (max_batch_size, max_seq_len, self.n_kv_heads, self.head_dim)
    alloc_opts = {"dtype": dtype, "device": device}

    # Allocate both tensors before registering either one, so a failed
    # allocation never leaves only one of the two buffers attached.
    key_store = torch.zeros(shape, **alloc_opts)
    value_store = torch.zeros(shape, **alloc_opts)

    # Registered with persistent=False, as in the original call sites.
    for buffer_name, tensor in (("cache_k", key_store), ("cache_v", value_store)):
        self.register_buffer(buffer_name, tensor, persistent=False)
| 115 | |
| 116 | def del_kv_cache(self): |
| 117 | self.cache_k = None |
no outgoing calls
no test coverage detected