MCPcopy
hub / github.com/google-deepmind/penzai / test_embeddings

Method test_embeddings

tests/nn/embedding_test.py:24–45  ·  view source on GitHub ↗
(self)

Source from the content-addressed store, hash-verified

class EmbeddingTest(absltest.TestCase):
  """Named-shape tests for the `pz.nn` embedding lookup and decode layers."""

  def test_embeddings(self):
    """Checks the output named shapes of EmbeddingLookup and EmbeddingDecode.

    Both layers are built from one shared `EmbeddingTable` with a 17-entry
    vocabulary (axis name "vocabulary") and two embedding axes, "foo" (5) and
    "bar" (7). Only the named shape of each result is checked, not its values.
    """
    # A fixed PRNG key keeps table initialization deterministic across runs.
    table = pz.nn.EmbeddingTable.from_config(
        name="embeddings",
        init_base_rng=jax.random.key(2),
        vocab_size=17,
        embedding_axes={"foo": 5, "bar": 7},
        vocabulary_axis="vocabulary",
    )

    with self.subTest("lookup"):
      # Token ids come from `pz.nx.arange` along a "baz" axis of length 3
      # (presumably ids 0..2). Each id should gain the embedding axes
      # ("foo", "bar") while the "baz" axis is preserved.
      lookup_layer = pz.nn.EmbeddingLookup(table)
      result = lookup_layer(pz.nx.arange("baz", 3))
      pz.chk.check_structure(
          result, pz.chk.ArraySpec(named_shape={"foo": 5, "bar": 7, "baz": 3})
      )

    with self.subTest("decode"):
      # Decoding should consume the embedding axes ("foo", "bar") and replace
      # them with the "vocabulary" axis, keeping "baz" untouched.
      decode_layer = pz.nn.EmbeddingDecode(table)
      result = decode_layer(pz.nx.ones({"foo": 5, "bar": 7, "baz": 3}))
      pz.chk.check_structure(
          result, pz.chk.ArraySpec(named_shape={"baz": 3, "vocabulary": 17})
      )
46
47
48if __name__ == "__main__":

Callers

nothing calls this directly

Calls 1

from_config · Method · 0.45

Tested by

no test coverage detected