MCPcopy
hub / github.com/google-deepmind/gemma / test_sliding_window

Function test_sliding_window

gemma/gm/nn/gemma3n/_modules_test.py:322–370  ·  view source on GitHub ↗
()

Source from the content-addressed store, hash-verified

320
321
def test_sliding_window():
  """Checks that local sliding-window attention changes the attention output.

  A reference `_modules.Attention` built with `_ATTN_TYPE` and a
  `LOCAL_SLIDING` variant with `sliding_window_size=2` are applied with the
  same parameters, inputs, segment positions, cache and mask. The test asserts
  that the two outputs are not (numerically) equal.
  """
  num_heads = 2
  head_dim = 4
  features = 8
  cache_size = 7
  batch_size = 2
  query_pre_attn_scalar = head_dim**-0.5
  seq_len = 6

  attn = _modules.Attention(
      num_heads=num_heads,
      num_kv_heads=num_heads,
      features=features,
      head_dim=head_dim,
      attn_type=_ATTN_TYPE,
      query_pre_attn_scalar=query_pre_attn_scalar,
  )

  params_rng, x_rng = jax.random.split(jax.random.PRNGKey(0))
  # Random (not all-ones) input so that different positions carry different
  # content. With a constant input every token is identical, so the value
  # vectors are presumably identical too, and the window could then only
  # affect the output via the extra cache slot (cache_size > seq_len).
  x = jax.random.normal(x_rng, (batch_size, seq_len, features))
  # Positions 0..seq_len-1 repeated for every batch row: (batch_size, seq_len).
  segment_pos = jnp.tile(jnp.arange(seq_len), (batch_size, 1))
  cache = _modules.Attention.init_cache(
      cache_size=cache_size,
      num_heads=num_heads,
      head_dim=head_dim,
      batch_size=batch_size,
      dtype=jnp.float32,
  )
  attn_mask = jnp.ones((batch_size, seq_len, cache_size))

  params = attn.init(params_rng, x, segment_pos, cache, attn_mask)
  _, output = attn.apply(params, x, segment_pos, cache, attn_mask)

  sliding_attn = _modules.Attention(
      num_heads=num_heads,
      num_kv_heads=num_heads,
      features=features,
      head_dim=head_dim,
      attn_type=_modules.AttentionType.LOCAL_SLIDING,
      sliding_window_size=2,
      query_pre_attn_scalar=query_pre_attn_scalar,
  )
  # Reuse the reference parameters so the only difference is the attention type.
  _, sliding_output = sliding_attn.apply(
      params, x, segment_pos, cache, attn_mask
  )

  # Compare with a tolerance: a bitwise-inequality check would also pass on
  # mere floating-point noise, which would not show the window has an effect.
  assert not jnp.allclose(output, sliding_output)
371
372
373def test_query_pre_attn_scalar_modifies_output():

Callers

nothing calls this directly

Calls 3

init · method · 0.80
apply · method · 0.80
init_cache · method · 0.45

Tested by

no test coverage detected