(
cls,
cache_size: int,
num_heads: int,
head_dim: int,
batch_size: int,
dtype: jnp.dtype = jnp.bfloat16,
)
| 417 | |
@classmethod
def init_cache(
    cls,
    cache_size: int,
    num_heads: int,
    head_dim: int,
    batch_size: int,
    dtype: jnp.dtype = jnp.bfloat16,
) -> LayerCache:
    """Create a zero-initialised cache dictionary.

    Args:
        cache_size: Size of the second axis of the 'k', 'v' and 'positions'
            arrays.
        num_heads: Size of the third axis of the 'k' and 'v' arrays.
        head_dim: Size of the last axis of the 'k' and 'v' arrays.
        batch_size: Size of the leading axis of every array in the cache.
        dtype: Element type of the 'k' and 'v' arrays.

    Returns:
        A dict with four zero-filled arrays:
        'v' and 'k' of shape (batch_size, cache_size, num_heads, head_dim)
        and type `dtype`; 'end_index' of shape (batch_size,) and type int32;
        'positions' of shape (batch_size, cache_size) and type int32.
    """
    del cls  # The class itself is not needed to build the cache.

    # Keys and values share one layout, so spell the shape out once.
    kv_shape = (batch_size, cache_size, num_heads, head_dim)
    values = jnp.zeros(kv_shape, dtype=dtype)
    keys = jnp.zeros(kv_shape, dtype=dtype)
    end_index = jnp.zeros((batch_size,), dtype=jnp.int32)
    # Positions are stored for the sliding window attention.
    positions = jnp.zeros((batch_size, cache_size), dtype=jnp.int32)

    return {
        'v': values,
        'k': keys,
        'end_index': end_index,
        'positions': positions,
    }
| 439 | |
| 440 | |
| 441 | class FeedForward(nn.Module): |
no outgoing calls