MCPcopy
hub / github.com/PaddlePaddle/PaddleFormers / update

Method update

paddleformers/transformers/cache_utils.py:262–304  ·  view source on GitHub ↗

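Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`. Parameters: key_states (`paddle.Tensor`): The new key states to cache. value_states (`paddle.Tensor`): The new value states to cache. layer_idx (`int`): The index of the layer to cache the states for. cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache subclass. Returns: a tuple containing the updated key and value states.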

(
        self,
        key_states: paddle.Tensor,
        value_states: paddle.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    )

Source from the content-addressed store, hash-verified

260 self.layers[layer_idx].offload()
261
def update(
    self,
    key_states: paddle.Tensor,
    value_states: paddle.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[paddle.Tensor, paddle.Tensor]:
    """
    Store new key/value states for one layer and return that layer's cached states.

    When ``self.layer_class_to_replicate`` is set (the cache was built without explicit
    ``layers``), new layer objects are appended until index ``layer_idx`` exists. The
    actual storage is delegated to ``self.layers[layer_idx].update``.

    If ``self.offloading`` is enabled:

    * before the update, the current device stream waits on ``self.prefetch_stream``
      and layer ``layer_idx + 1`` is prefetched;
    * after the update, layer ``layer_idx`` is passed to ``self.offload``.

    Parameters:
        key_states (`paddle.Tensor`):
            The new key states to cache.
        value_states (`paddle.Tensor`):
            The new value states to cache.
        layer_idx (`int`):
            Index of the layer whose cache receives the states.
        cache_kwargs (`dict[str, Any]`, *optional*):
            Extra, subclass-specific arguments forwarded unchanged to the layer's ``update``.

    Return:
        A ``(keys, values)`` tuple unpacked from the layer's ``update`` result.
    """
    replicate = self.layer_class_to_replicate
    if replicate is not None:
        # No explicit `layers` were provided: grow the list lazily so `layer_idx` is valid.
        for _ in range(layer_idx + 1 - len(self.layers)):
            self.layers.append(replicate())

    if self.offloading:
        # `paddle.device.current_stream` cannot take `key_states.place` (printed as
        # `Place(gpu:N)`) directly, so a "gpu:N" device string is built from it.
        # NOTE(review): the "gpu" prefix is hard-coded; other backends such as xpu may need
        # a different device string - confirm before enabling offloading there.
        device = f"gpu:{key_states.place.gpu_device_id()}"
        paddle.device.current_stream(device).wait_stream(self.prefetch_stream)
        # Ask for the next layer's cache ahead of time.
        self.prefetch(layer_idx + 1, self.only_non_sliding)

    key_out, value_out = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)

    if self.offloading:
        # Hand the just-updated layer back to the offloading mechanism.
        self.offload(layer_idx, self.only_non_sliding)

    return key_out, value_out
305
306 def early_initialization(
307 self, batch_size: int, num_heads: int, head_dim: int, dtype: paddle.dtype, device: paddle.device

Callers 15

update_training_args · Method · 0.45
merge_configs · Function · 0.45
start_local_trainers_cpu · Function · 0.45
start_local_trainers · Function · 0.45
load_test_config · Function · 0.45
update_params · Function · 0.45
init_dist_env · Method · 0.45
start_local_trainers_cpu · Function · 0.45
start_local_trainers · Function · 0.45

Calls 3

prefetch · Method · 0.95
offload · Method · 0.95
append · Method · 0.45