(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs=None,
)
| 150 | super().record_kwargs(layer_idx, d) |
| 151 | |
def update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs=None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Move CUDA-resident key/value states to the CPU, then defer to the parent.

    Args:
        key_states: Key tensor for this layer. Copied to the CPU when it is
            a tensor living on a CUDA device; otherwise passed through as-is.
        value_states: Value tensor for this layer, handled the same way as
            ``key_states``.
        layer_idx: Forwarded unchanged to the parent ``update``.
        cache_kwargs: Forwarded unchanged to the parent ``update``.

    Returns:
        Whatever the parent class's ``update`` returns for the (possibly
        CPU-relocated) states.
    """

    def _host_copy(states):
        # torch.is_tensor(None) is False, so None (and any other non-tensor
        # value) falls through untouched, as do tensors that are not on CUDA.
        if torch.is_tensor(states) and states.is_cuda:
            return states.cpu()
        return states

    cpu_keys = _host_copy(key_states)
    cpu_values = _host_copy(value_states)
    return super().update(cpu_keys, cpu_values, layer_idx, cache_kwargs)
| 164 | |
| 165 | def update_router_kcache( |
| 166 | self, |
no outgoing calls
no test coverage detected