github.com/google-deepmind/gemma

Method init_cache

gemma/gm/nn/gemma3n/_modules.py:419–438
init_cache(
    cls,
    cache_size: int,
    num_heads: int,
    head_dim: int,
    batch_size: int,
    dtype: jnp.dtype = jnp.bfloat16,
) -> LayerCache
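For capacity planning, the cache's memory footprint follows from the buffer shapes in the method body: 'k' and 'v' each hold batch_size × cache_size × num_heads × head_dim elements of `dtype`, while 'end_index' and 'positions' are int32. The helper below is a minimal sketch; `kv_cache_bytes` and its `kv_itemsize` parameter are hypothetical and not part of the Gemma code. It assumes the default bfloat16 dtype (2 bytes per element) unless told otherwise.

```python
def kv_cache_bytes(
    cache_size: int,
    num_heads: int,
    head_dim: int,
    batch_size: int,
    kv_itemsize: int = 2,  # bytes per 'k'/'v' element; bfloat16 is 2 bytes
) -> int:
  """Hypothetical helper: bytes held by one layer's cache dict."""
  kv_elems = batch_size * cache_size * num_heads * head_dim
  kv_bytes = 2 * kv_elems * kv_itemsize  # 'k' and 'v' have identical shapes
  end_index_bytes = batch_size * 4  # int32, shape (batch_size,)
  positions_bytes = batch_size * cache_size * 4  # int32, (batch_size, cache_size)
  return kv_bytes + end_index_bytes + positions_bytes


# One sequence, 4096 cache slots, 8 heads of dimension 256, bfloat16:
print(kv_cache_bytes(4096, 8, 256, 1))  # → 33570820 (just over 32 MiB)
```

The 'k'/'v' buffers dominate; for a float32 cache, pass `kv_itemsize=4` to double that term.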


@classmethod
def init_cache(
    cls,
    cache_size: int,
    num_heads: int,
    head_dim: int,
    batch_size: int,
    dtype: jnp.dtype = jnp.bfloat16,
) -> LayerCache:
  del cls  # not used
  return {
      'v': jnp.zeros(
          (batch_size, cache_size, num_heads, head_dim), dtype=dtype
      ),
      'k': jnp.zeros(
          (batch_size, cache_size, num_heads, head_dim), dtype=dtype
      ),
      'end_index': jnp.zeros((batch_size,), dtype=jnp.int32),
      # Save the positions for the sliding window attention.
      'positions': jnp.zeros((batch_size, cache_size), dtype=jnp.int32),
  }
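The method only allocates zero-filled buffers; how they are filled is up to the attention layer that owns the cache. The sketch below copies init_cache as a free function so it runs on its own (LayerCache is approximated as a plain dict), then adds two hypothetical helpers, `write_step` and `sliding_window_mask`, which are not part of the Gemma code. They assume a ring-buffer layout in which 'end_index' counts the tokens written so far, to illustrate why storing a position per slot matters once slots are reused.

```python
from typing import Any

import jax.numpy as jnp

# Assumption: LayerCache behaves like a dict of arrays.
LayerCache = dict[str, Any]


def init_cache(cache_size: int, num_heads: int, head_dim: int,
               batch_size: int, dtype: jnp.dtype = jnp.bfloat16) -> LayerCache:
  """Free-function copy of the classmethod above."""
  kv_shape = (batch_size, cache_size, num_heads, head_dim)
  return {
      'v': jnp.zeros(kv_shape, dtype=dtype),
      'k': jnp.zeros(kv_shape, dtype=dtype),
      'end_index': jnp.zeros((batch_size,), dtype=jnp.int32),
      'positions': jnp.zeros((batch_size, cache_size), dtype=jnp.int32),
  }


def write_step(cache: LayerCache, k_new, v_new, position) -> LayerCache:
  """Hypothetical one-token write (NOT Gemma's update logic).

  k_new, v_new: [batch, num_heads, head_dim]; position: [batch] int32.
  Writes into slot end_index % cache_size (ring buffer) and advances end_index.
  """
  batch, cache_size = cache['k'].shape[:2]
  slot = cache['end_index'] % cache_size  # [batch]
  b = jnp.arange(batch)
  return {
      'k': cache['k'].at[b, slot].set(k_new.astype(cache['k'].dtype)),
      'v': cache['v'].at[b, slot].set(v_new.astype(cache['v'].dtype)),
      'end_index': cache['end_index'] + 1,
      'positions': cache['positions'].at[b, slot].set(position),
  }


def sliding_window_mask(cache: LayerCache, query_pos, window: int):
  """Hypothetical mask: True where a slot is filled and within the window."""
  cache_size = cache['positions'].shape[1]
  n_filled = jnp.minimum(cache['end_index'], cache_size)  # [batch]
  filled = jnp.arange(cache_size)[None, :] < n_filled[:, None]
  recent = (query_pos[:, None] - cache['positions']) < window
  return filled & recent  # [batch, cache_size]


# Usage: 6 decode steps into a 4-slot cache, so slots 0 and 1 get reused.
cache = init_cache(cache_size=4, num_heads=2, head_dim=3, batch_size=1)
print({name: (x.shape, x.dtype) for name, x in cache.items()})
for t in range(6):
  kv = jnp.ones((1, 2, 3)) * t
  cache = write_step(cache, kv, kv, jnp.array([t], dtype=jnp.int32))
print(cache['positions'])  # → [[4 5 2 3]]
print(sliding_window_mask(cache, jnp.array([5]), window=3))  # → [[ True  True False  True]]
```

After the buffer wraps, slot order no longer matches token order, so the mask is computed from 'positions' rather than from slot indices: for query position 5 and a window of 3, the slot holding position 2 is excluded.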

Callers 8

_get_attn_output · Function · 0.45
test_sliding_window · Function · 0.45
test_block · Function · 0.45
test_block_with_altup · Function · 0.45
test_block_with_laurel · Function · 0.45

Calls

no outgoing calls to other functions in this repository (the body only calls jnp.zeros)

Tested by 8

_get_attn_output · Function · 0.36
test_sliding_window · Function · 0.36
test_block · Function · 0.36
test_block_with_altup · Function · 0.36
test_block_with_laurel · Function · 0.36