()
| 320 | |
| 321 | |
def test_sliding_window():
    """Checks that LOCAL_SLIDING attention changes the attention output.

    A reference `Attention` module (built with `_ATTN_TYPE`) and a
    `LOCAL_SLIDING` module with `sliding_window_size=2` are applied with the
    same parameters, input, positions, cache and mask. The test passes when
    at least one element of the two outputs differs.
    """
    num_heads, head_dim, features = 2, 4, 8
    cache_size, batch_size, seq_len = 7, 2, 6
    query_pre_attn_scalar = head_dim**-0.5

    # Constructor arguments common to both modules; only the attention type
    # (plus the window size on the sliding variant) differs between them.
    shared_kwargs = dict(
        num_heads=num_heads,
        num_kv_heads=num_heads,
        features=features,
        head_dim=head_dim,
        query_pre_attn_scalar=query_pre_attn_scalar,
    )
    reference_attn = _modules.Attention(attn_type=_ATTN_TYPE, **shared_kwargs)
    sliding_attn = _modules.Attention(
        attn_type=_modules.AttentionType.LOCAL_SLIDING,
        sliding_window_size=2,
        **shared_kwargs,
    )

    key = jax.random.PRNGKey(0)
    inputs = jnp.ones((batch_size, seq_len, features))
    # Each batch row holds positions 0..seq_len-1, giving shape
    # (batch_size, seq_len).
    positions = jnp.tile(jnp.arange(seq_len), (batch_size, 1))
    cache = _modules.Attention.init_cache(
        cache_size=cache_size,
        num_heads=num_heads,
        head_dim=head_dim,
        batch_size=batch_size,
        dtype=jnp.float32,
    )
    # All-ones mask of shape (batch_size, seq_len, cache_size).
    mask = jnp.ones((batch_size, seq_len, cache_size))

    # Parameters are initialized once, from the reference module, and then
    # reused by the sliding module, so only the attention type can account
    # for any difference in output.
    params = reference_attn.init(key, inputs, positions, cache, mask)
    _, reference_out = reference_attn.apply(
        params, inputs, positions, cache, mask
    )
    _, sliding_out = sliding_attn.apply(params, inputs, positions, cache, mask)

    assert jnp.any(reference_out != sliding_out)
| 371 | |
| 372 | |
| 373 | def test_query_pre_attn_scalar_modifies_output(): |
nothing calls this directly
no test coverage detected