(
self,
max_batch_size: int,
max_seq_len: int,
device: torch.device,
dtype: torch.dtype,
)
| 288 | return output |
| 289 | |
def init_kv_cache(
    self,
    max_batch_size: int,
    max_seq_len: int,
    device: torch.device,
    dtype: torch.dtype,
):
    """Set up the KV cache on the self-attention module of every layer.

    Walks ``self.layers`` in order and delegates to each layer's
    ``self_attn.init_kv_cache``, forwarding the arguments unchanged.

    Args:
        max_batch_size: Forwarded as the first positional argument.
            Presumably the largest batch the cache must hold; confirm
            against the attention module.
        max_seq_len: Forwarded as the second positional argument.
            Presumably the longest sequence the cache must hold; confirm
            against the attention module.
        device: Forwarded as the ``device`` keyword argument.
        dtype: Forwarded as the ``dtype`` keyword argument.
    """
    # The delegate receives device/dtype by keyword, so their order in the
    # signature above does not need to match the delegate's own signature.
    cache_options = {"dtype": dtype, "device": device}
    for block in self.layers:
        attention = block.self_attn
        attention.init_kv_cache(max_batch_size, max_seq_len, **cache_options)
| 301 | |
| 302 | def del_kv_cache(self): |
| 303 | for layer in self.layers: |
nothing calls this directly
no test coverage detected