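Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Parameters: key_states (`paddle.Tensor`): The new key states to cache. value_states (`paddle.Tensor`): The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache subclass; these are specific to each subclass and allow new types of cache to be created. Return: A tuple containing the updated key and value states.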
(
self,
key_states: paddle.Tensor,
value_states: paddle.Tensor,
layer_idx: int,
cache_kwargs: Optional[dict[str, Any]] = None,
)
| 260 | self.layers[layer_idx].offload() |
| 261 | |
def update(
    self,
    key_states: paddle.Tensor,
    value_states: paddle.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[paddle.Tensor, paddle.Tensor]:
    """
    Store the new `key_states` and `value_states` in the cache of layer `layer_idx`.

    When `layer_class_to_replicate` is set, layers are created on demand until `layer_idx` is a valid index
    into `self.layers`. When `offloading` is enabled, the call first waits on the prefetch stream and requests
    a prefetch for `layer_idx + 1`, then calls `offload` for `layer_idx` once the layer has been updated.

    Parameters:
        key_states (`paddle.Tensor`):
            The new key states to cache.
        value_states (`paddle.Tensor`):
            The new value states to cache.
        layer_idx (`int`):
            The index of the layer to cache the states for.
        cache_kwargs (`dict[str, Any]`, *optional*):
            Additional arguments for the cache subclass. These are specific to each subclass and allow new
            types of cache to be created.

    Return:
        A tuple containing the updated key and value states, as returned by the layer's own `update`.
    """
    # No explicit `layers` were provided: grow the list lazily so that `layer_idx` becomes addressable.
    layer_factory = self.layer_class_to_replicate
    if layer_factory is not None:
        for _ in range(len(self.layers), layer_idx + 1):
            self.layers.append(layer_factory())

    if self.offloading:
        # Make the current stream wait for any pending prefetch, then kick off the prefetch of the next layer.
        # Note: `current_stream` cannot take `key_states.place` directly (it prints as `Place(gpu:0)`), so the
        # device is spelled out as a string instead. This may cause unknown issues on other device kinds such
        # as xpu, so attention is needed here.
        device_name = f"gpu:{key_states.place.gpu_device_id()}"
        compute_stream = paddle.device.current_stream(device_name)
        compute_stream.wait_stream(self.prefetch_stream)
        self.prefetch(layer_idx + 1, self.only_non_sliding)

    target_layer = self.layers[layer_idx]
    updated_keys, updated_values = target_layer.update(key_states, value_states, cache_kwargs)

    if self.offloading:
        self.offload(layer_idx, self.only_non_sliding)

    return updated_keys, updated_values
| 305 | |
| 306 | def early_initialization( |
| 307 | self, batch_size: int, num_heads: int, head_dim: int, dtype: paddle.dtype, device: paddle.device |